Added LiteLLM to the stack
This commit is contained in:
@@ -0,0 +1,221 @@
|
||||
import os
|
||||
import sys
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import aiohttp
|
||||
import aiohttp.client_exceptions
|
||||
import aiohttp.http_exceptions
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../../../..")
|
||||
) # Adds the parent directory to the system path
|
||||
|
||||
from litellm.llms.custom_httpx.aiohttp_transport import (
|
||||
AiohttpResponseStream,
|
||||
LiteLLMAiohttpTransport,
|
||||
map_aiohttp_exceptions,
|
||||
)
|
||||
|
||||
|
||||
class MockAiohttpResponse:
|
||||
"""Mock aiohttp ClientResponse for testing"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
status=200,
|
||||
headers=None,
|
||||
content_chunks=None,
|
||||
exception_to_raise=None,
|
||||
exception_at_chunk=None,
|
||||
):
|
||||
self.status = status
|
||||
self.headers = headers or {}
|
||||
self.content = MockContent(
|
||||
content_chunks, exception_to_raise, exception_at_chunk
|
||||
)
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
pass
|
||||
|
||||
|
||||
class MockContent:
|
||||
"""Mock aiohttp response content for testing"""
|
||||
|
||||
def __init__(self, chunks=None, exception_to_raise=None, exception_at_chunk=None):
|
||||
self.chunks = chunks or [b"chunk1", b"chunk2", b"chunk3"]
|
||||
self.exception_to_raise = exception_to_raise
|
||||
self.exception_at_chunk = exception_at_chunk or (len(self.chunks) - 1)
|
||||
self.chunk_index = 0
|
||||
|
||||
async def iter_chunked(self, chunk_size):
|
||||
for i, chunk in enumerate(self.chunks):
|
||||
if self.exception_to_raise and i == self.exception_at_chunk:
|
||||
# Raise exception at specified chunk to simulate partial transfer
|
||||
raise self.exception_to_raise
|
||||
yield chunk
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aiohttp_response_stream_normal_flow():
|
||||
"""Test normal flow of AiohttpResponseStream without exceptions"""
|
||||
mock_response = MockAiohttpResponse(content_chunks=[b"hello", b"world", b"test"])
|
||||
|
||||
stream = AiohttpResponseStream(mock_response) # type: ignore
|
||||
chunks = []
|
||||
|
||||
async for chunk in stream:
|
||||
chunks.append(chunk)
|
||||
|
||||
assert chunks == [b"hello", b"world", b"test"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_transfer_encoding_error_no_httpx_read_error():
|
||||
"""Test that TransferEncodingError doesn't get converted to httpx.ReadError"""
|
||||
import logging
|
||||
|
||||
# Create a TransferEncodingError wrapped in ClientPayloadError (like in real scenarios)
|
||||
transfer_error = aiohttp.http_exceptions.TransferEncodingError(
|
||||
message="400, message: Not enough data for satisfy transfer length header."
|
||||
)
|
||||
|
||||
# Wrap it in ClientPayloadError as aiohttp does
|
||||
client_payload_error = aiohttp.ClientPayloadError(
|
||||
"Response payload is not completed"
|
||||
)
|
||||
client_payload_error.__cause__ = transfer_error
|
||||
|
||||
mock_response = MockAiohttpResponse(
|
||||
content_chunks=[b"chunk1", b"chunk2", b"chunk3"],
|
||||
exception_to_raise=client_payload_error,
|
||||
exception_at_chunk=1, # Error occurs at chunk 1
|
||||
)
|
||||
|
||||
stream = AiohttpResponseStream(mock_response) # type: ignore
|
||||
received_chunks = []
|
||||
|
||||
# This should NOT raise httpx.ReadError or any other exception
|
||||
# It should handle the error gracefully and just return what was received
|
||||
async for chunk in stream:
|
||||
received_chunks.append(chunk)
|
||||
print(f"received_chunks: {received_chunks}")
|
||||
|
||||
# Should have received the first chunk before the error
|
||||
assert received_chunks == [b"chunk1"]
|
||||
assert len(received_chunks) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_client_payload_error_graceful_handling():
|
||||
"""Test that ClientPayloadError is handled gracefully without stacktrace"""
|
||||
# Create a ClientPayloadError directly
|
||||
client_error = aiohttp.client_exceptions.ClientPayloadError(
|
||||
"Response payload is not completed"
|
||||
)
|
||||
|
||||
mock_response = MockAiohttpResponse(
|
||||
content_chunks=[b"data1", b"data2", b"data3"],
|
||||
exception_to_raise=client_error,
|
||||
exception_at_chunk=2, # Error occurs at chunk 2
|
||||
)
|
||||
|
||||
stream = AiohttpResponseStream(mock_response) # type: ignore
|
||||
received_chunks = []
|
||||
|
||||
# This should handle the error gracefully without raising
|
||||
async for chunk in stream:
|
||||
received_chunks.append(chunk)
|
||||
|
||||
# Should have received chunks before the error
|
||||
assert received_chunks == [b"data1", b"data2"]
|
||||
assert len(received_chunks) == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unknown_aiohttp_exception_gets_mapped():
|
||||
"""Test that unknown aiohttp exceptions still get mapped to httpx exceptions"""
|
||||
# Create an aiohttp exception that's not specifically handled
|
||||
# Using InvalidURL which should map to httpx.InvalidURL
|
||||
invalid_url_error = aiohttp.InvalidURL("Invalid URL format")
|
||||
|
||||
mock_response = MockAiohttpResponse(
|
||||
content_chunks=[b"chunk1", b"chunk2"],
|
||||
exception_to_raise=invalid_url_error,
|
||||
exception_at_chunk=0, # Error occurs immediately
|
||||
)
|
||||
|
||||
stream = AiohttpResponseStream(mock_response) # type: ignore
|
||||
|
||||
# This should raise httpx.InvalidURL (mapped from aiohttp.InvalidURL)
|
||||
with pytest.raises(httpx.InvalidURL):
|
||||
async for chunk in stream:
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_timeout_exception_gets_mapped():
|
||||
"""Test that aiohttp timeout exceptions get mapped to httpx timeout exceptions"""
|
||||
# Create an aiohttp timeout exception
|
||||
timeout_error = aiohttp.ServerTimeoutError("Server timeout")
|
||||
|
||||
mock_response = MockAiohttpResponse(
|
||||
content_chunks=[b"chunk1", b"chunk2"],
|
||||
exception_to_raise=timeout_error,
|
||||
exception_at_chunk=1, # Error occurs at chunk 1
|
||||
)
|
||||
|
||||
stream = AiohttpResponseStream(mock_response) # type: ignore
|
||||
received_chunks = []
|
||||
|
||||
# This should raise httpx.TimeoutException (mapped from aiohttp.ServerTimeoutError)
|
||||
with pytest.raises(httpx.TimeoutException):
|
||||
async for chunk in stream:
|
||||
received_chunks.append(chunk)
|
||||
|
||||
# Should have received the first chunk before the error
|
||||
assert received_chunks == [b"chunk1"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_async_request_uses_env_proxy(monkeypatch):
|
||||
"""Aiohttp transport should honor HTTP(S)_PROXY env vars"""
|
||||
proxy_url = "http://proxy.local:3128"
|
||||
monkeypatch.setenv("HTTP_PROXY", proxy_url)
|
||||
monkeypatch.setenv("http_proxy", proxy_url)
|
||||
monkeypatch.setenv("HTTPS_PROXY", proxy_url)
|
||||
monkeypatch.setenv("https_proxy", proxy_url)
|
||||
monkeypatch.delenv("DISABLE_AIOHTTP_TRUST_ENV", raising=False)
|
||||
|
||||
captured = {}
|
||||
|
||||
class FakeSession:
|
||||
def request(self, *args, **kwargs):
|
||||
captured["proxy"] = kwargs.get("proxy")
|
||||
|
||||
class Resp:
|
||||
status = 200
|
||||
headers = {}
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
pass
|
||||
|
||||
@property
|
||||
def content(self):
|
||||
class C:
|
||||
async def iter_chunked(self, size):
|
||||
yield b""
|
||||
|
||||
return C()
|
||||
|
||||
return Resp()
|
||||
|
||||
transport = LiteLLMAiohttpTransport(client=lambda: FakeSession())
|
||||
request = httpx.Request("GET", "http://example.com")
|
||||
await transport.handle_async_request(request)
|
||||
|
||||
assert captured["proxy"] == proxy_url
|
@@ -0,0 +1,185 @@
|
||||
import io
|
||||
import os
|
||||
import pathlib
|
||||
import ssl
|
||||
import sys
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import certifi
|
||||
import httpx
|
||||
import pytest
|
||||
from aiohttp import ClientSession, TCPConnector
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../../../..")
|
||||
) # Adds the parent directory to the system path
|
||||
import litellm
|
||||
from litellm.llms.custom_httpx.aiohttp_transport import LiteLLMAiohttpTransport
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, get_ssl_configuration
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ssl_security_level(monkeypatch):
|
||||
with patch.dict(os.environ, clear=True):
|
||||
# Set environment variable for SSL security level
|
||||
monkeypatch.setenv("SSL_SECURITY_LEVEL", "DEFAULT@SECLEVEL=1")
|
||||
|
||||
# Create async client with SSL verification disabled to isolate SSL context testing
|
||||
client = AsyncHTTPHandler()
|
||||
|
||||
# Get the transport (should be LiteLLMAiohttpTransport)
|
||||
transport = client.client._transport
|
||||
|
||||
# Get the aiohttp ClientSession
|
||||
client_session = transport._get_valid_client_session()
|
||||
|
||||
# Get the connector from the session
|
||||
connector = client_session.connector
|
||||
|
||||
# Get the SSL context from the connector
|
||||
ssl_context = connector._ssl
|
||||
print("ssl_context", ssl_context)
|
||||
|
||||
# Verify that the SSL context exists and has the correct cipher string
|
||||
assert isinstance(ssl_context, ssl.SSLContext)
|
||||
# Optionally, check the ciphers string if needed
|
||||
# assert "DEFAULT@SECLEVEL=1" in ssl_context.get_ciphers()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_force_ipv4_transport():
|
||||
"""Test transport creation with force_ipv4 enabled"""
|
||||
litellm.force_ipv4 = True
|
||||
litellm.disable_aiohttp_transport = True
|
||||
|
||||
transport = AsyncHTTPHandler._create_async_transport()
|
||||
|
||||
# Should get an AsyncHTTPTransport
|
||||
assert isinstance(transport, httpx.AsyncHTTPTransport)
|
||||
# Verify IPv4 configuration through a request
|
||||
client = httpx.AsyncClient(transport=transport)
|
||||
try:
|
||||
response = await client.get("http://example.com")
|
||||
assert response.status_code == 200
|
||||
finally:
|
||||
await client.aclose()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ssl_context_transport():
|
||||
"""Test transport creation with SSL context"""
|
||||
# Create a test SSL context
|
||||
ssl_context = ssl.create_default_context()
|
||||
|
||||
transport = AsyncHTTPHandler._create_async_transport(ssl_context=ssl_context)
|
||||
assert transport is not None
|
||||
|
||||
if isinstance(transport, LiteLLMAiohttpTransport):
|
||||
# Get the client session and verify SSL context is passed through
|
||||
client_session = transport._get_valid_client_session()
|
||||
assert isinstance(client_session, ClientSession)
|
||||
assert isinstance(client_session.connector, TCPConnector)
|
||||
# Verify the connector has SSL context set by checking if it's using SSL
|
||||
assert client_session.connector._ssl is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aiohttp_disabled_transport():
|
||||
"""Test transport creation with aiohttp disabled"""
|
||||
litellm.disable_aiohttp_transport = True
|
||||
litellm.force_ipv4 = False
|
||||
|
||||
transport = AsyncHTTPHandler._create_async_transport()
|
||||
|
||||
# Should get None when both aiohttp is disabled and force_ipv4 is False
|
||||
assert transport is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ssl_verification_with_aiohttp_transport():
|
||||
"""
|
||||
Test aiohttp respects ssl_verify=False
|
||||
|
||||
We validate that the ssl settings for a litellm transport match what a ssl verify=False aiohttp client would have.
|
||||
|
||||
"""
|
||||
import aiohttp
|
||||
|
||||
# Create a test SSL context
|
||||
litellm_async_client = AsyncHTTPHandler(ssl_verify=False)
|
||||
|
||||
transport_connector = (
|
||||
litellm_async_client.client._transport._get_valid_client_session().connector
|
||||
)
|
||||
print("transport_connector", transport_connector)
|
||||
print("transport_connector._ssl", transport_connector._ssl)
|
||||
|
||||
aiohttp_session = aiohttp.ClientSession(
|
||||
connector=aiohttp.TCPConnector(verify_ssl=False)
|
||||
)
|
||||
print("aiohttp_session", aiohttp_session)
|
||||
print("aiohttp_session._ssl", aiohttp_session.connector._ssl)
|
||||
|
||||
# assert both litellm transport and aiohttp session have ssl_verify=False
|
||||
assert transport_connector._ssl == aiohttp_session.connector._ssl
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aiohttp_transport_trust_env_setting(monkeypatch):
|
||||
"""Test that trust_env setting is properly configured in aiohttp transport"""
|
||||
# Test 1: Default trust_env behavior
|
||||
transport = AsyncHTTPHandler._create_aiohttp_transport()
|
||||
client_session = transport._get_valid_client_session()
|
||||
|
||||
# Default should be False (litellm.aiohttp_trust_env default)
|
||||
default_trust_env = getattr(litellm, 'aiohttp_trust_env', False)
|
||||
assert client_session._trust_env == default_trust_env
|
||||
|
||||
# Test 2: Environment variable override
|
||||
monkeypatch.setenv("AIOHTTP_TRUST_ENV", "True")
|
||||
transport_with_env = AsyncHTTPHandler._create_aiohttp_transport()
|
||||
client_session_with_env = transport_with_env._get_valid_client_session()
|
||||
|
||||
# Should be True when environment variable is set
|
||||
assert client_session_with_env._trust_env is True
|
||||
|
||||
# Test 3: Verify environment variable with False value
|
||||
monkeypatch.setenv("AIOHTTP_TRUST_ENV", "False")
|
||||
transport_with_false_env = AsyncHTTPHandler._create_aiohttp_transport()
|
||||
client_session_with_false_env = transport_with_false_env._get_valid_client_session()
|
||||
|
||||
# Should respect the litellm.aiohttp_trust_env setting when env var is False
|
||||
assert client_session_with_false_env._trust_env == default_trust_env
|
||||
|
||||
|
||||
def test_get_ssl_configuration():
|
||||
"""Test that get_ssl_configuration() returns a proper SSL context with certifi CA bundle
|
||||
when no environment variables are set."""
|
||||
with patch.dict(os.environ, clear=True):
|
||||
with patch('ssl.create_default_context') as mock_create_context:
|
||||
# Mock the return value
|
||||
mock_ssl_context = MagicMock(spec=ssl.SSLContext)
|
||||
mock_create_context.return_value = mock_ssl_context
|
||||
|
||||
# Call the static method
|
||||
result = get_ssl_configuration()
|
||||
|
||||
# Verify ssl.create_default_context was called with certifi's CA file
|
||||
expected_ca_file = certifi.where()
|
||||
mock_create_context.assert_called_once_with(cafile=expected_ca_file)
|
||||
|
||||
# Verify it returns the mocked SSL context
|
||||
assert result == mock_ssl_context
|
||||
|
||||
|
||||
def test_get_ssl_configuration_integration():
|
||||
"""Integration test that _get_ssl_context() returns a working SSL context"""
|
||||
# Call the static method without mocking
|
||||
ssl_context = get_ssl_configuration()
|
||||
|
||||
# Verify it returns an SSLContext instance
|
||||
assert isinstance(ssl_context, ssl.SSLContext)
|
||||
|
||||
# Verify it has basic SSL context properties
|
||||
assert ssl_context.protocol is not None
|
||||
assert ssl_context.verify_mode is not None
|
@@ -0,0 +1,77 @@
|
||||
import io
|
||||
import os
|
||||
import pathlib
|
||||
import ssl
|
||||
import sys
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../../../..")
|
||||
) # Adds the parent directory to the system path
|
||||
import litellm
|
||||
from litellm.llms.custom_httpx.llm_http_handler import BaseLLMHTTPHandler
|
||||
|
||||
|
||||
def test_prepare_fake_stream_request():
|
||||
# Initialize the BaseLLMHTTPHandler
|
||||
handler = BaseLLMHTTPHandler()
|
||||
|
||||
# Test case 1: fake_stream is True
|
||||
stream = True
|
||||
data = {
|
||||
"stream": True,
|
||||
"model": "gpt-4",
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
}
|
||||
fake_stream = True
|
||||
|
||||
result_stream, result_data = handler._prepare_fake_stream_request(
|
||||
stream=stream, data=data, fake_stream=fake_stream
|
||||
)
|
||||
|
||||
# Verify that stream is set to False
|
||||
assert result_stream is False
|
||||
# Verify that "stream" key is removed from data
|
||||
assert "stream" not in result_data
|
||||
# Verify other data remains unchanged
|
||||
assert result_data["model"] == "gpt-4"
|
||||
assert result_data["messages"] == [{"role": "user", "content": "Hello"}]
|
||||
|
||||
# Test case 2: fake_stream is False
|
||||
stream = True
|
||||
data = {
|
||||
"stream": True,
|
||||
"model": "gpt-4",
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
}
|
||||
fake_stream = False
|
||||
|
||||
result_stream, result_data = handler._prepare_fake_stream_request(
|
||||
stream=stream, data=data, fake_stream=fake_stream
|
||||
)
|
||||
|
||||
# Verify that stream remains True
|
||||
assert result_stream is True
|
||||
# Verify that data remains unchanged
|
||||
assert "stream" in result_data
|
||||
assert result_data["stream"] is True
|
||||
assert result_data["model"] == "gpt-4"
|
||||
assert result_data["messages"] == [{"role": "user", "content": "Hello"}]
|
||||
|
||||
# Test case 3: data doesn't have stream key but fake_stream is True
|
||||
stream = True
|
||||
data = {"model": "gpt-4", "messages": [{"role": "user", "content": "Hello"}]}
|
||||
fake_stream = True
|
||||
|
||||
result_stream, result_data = handler._prepare_fake_stream_request(
|
||||
stream=stream, data=data, fake_stream=fake_stream
|
||||
)
|
||||
|
||||
# Verify that stream is set to False
|
||||
assert result_stream is False
|
||||
# Verify that data remains unchanged (since there was no stream key to remove)
|
||||
assert "stream" not in result_data
|
||||
assert result_data["model"] == "gpt-4"
|
||||
assert result_data["messages"] == [{"role": "user", "content": "Hello"}]
|
Reference in New Issue
Block a user