Added LiteLLM to the stack
@@ -0,0 +1,154 @@
import json
import os
import sys
from datetime import datetime
from typing import Any, Dict, List, Optional
from unittest.mock import MagicMock, patch

import pytest

sys.path.insert(
    0, os.path.abspath("../../..")
)  # Adds the parent directory to the system path

from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.proxy.pass_through_endpoints.llm_provider_handlers.anthropic_passthrough_logging_handler import (
    AnthropicPassthroughLoggingHandler,
)


class TestAnthropicLoggingHandlerModelFallback:
    """Test the model fallback logic in the anthropic passthrough logging handler."""

    def setup_method(self):
        """Set up test fixtures"""
        self.start_time = datetime.now()
        self.end_time = datetime.now()
        self.mock_chunks = [
            '{"type": "message_start", "message": {"id": "msg_123", "model": "claude-3-haiku-20240307"}}',
            '{"type": "content_block_delta", "delta": {"text": "Hello"}}',
            '{"type": "content_block_delta", "delta": {"text": " world"}}',
            '{"type": "message_stop"}',
        ]

    def _create_mock_logging_obj(self, model_in_details: Optional[str] = None) -> LiteLLMLoggingObj:
        """Create a mock logging object with an optional model in model_call_details"""
        mock_logging_obj = MagicMock()

        if model_in_details:
            # Populate model_call_details with the given model under the 'model' key
            mock_logging_obj.model_call_details = {"model": model_in_details}
        else:
            # Leave model_call_details as an empty dict (no fallback model available)
            mock_logging_obj.model_call_details = {}

        return mock_logging_obj

    def _create_mock_passthrough_handler(self):
        """Create a mock passthrough success handler"""
        return MagicMock()

    @patch.object(AnthropicPassthroughLoggingHandler, "_build_complete_streaming_response")
    @patch.object(AnthropicPassthroughLoggingHandler, "_create_anthropic_response_logging_payload")
    def test_model_from_request_body_used_when_present(self, mock_create_payload, mock_build_response):
        """Test that the model from request_body is used when present"""
        # Arrange
        request_body = {"model": "claude-3-sonnet-20240229"}
        logging_obj = self._create_mock_logging_obj(model_in_details="claude-3-haiku-20240307")
        passthrough_handler = self._create_mock_passthrough_handler()

        # Mock successful response building
        mock_build_response.return_value = MagicMock()
        mock_create_payload.return_value = {"test": "payload"}

        # Act
        result = AnthropicPassthroughLoggingHandler._handle_logging_anthropic_collected_chunks(
            litellm_logging_obj=logging_obj,
            passthrough_success_handler_obj=passthrough_handler,
            url_route="/anthropic/v1/messages",
            request_body=request_body,
            endpoint_type="messages",
            start_time=self.start_time,
            all_chunks=self.mock_chunks,
            end_time=self.end_time,
        )

        # Assert
        assert result is not None
        # Verify that _build_complete_streaming_response was called with the request_body model
        mock_build_response.assert_called_once()
        call_args = mock_build_response.call_args
        assert call_args[1]["model"] == "claude-3-sonnet-20240229"  # Should use request_body model

    def test_model_fallback_logic_isolated(self):
        """Test just the model fallback logic in isolation"""
        # Test case 1: Model from request body
        request_body = {"model": "claude-3-sonnet-20240229"}
        logging_obj = self._create_mock_logging_obj(model_in_details="claude-3-haiku-20240307")

        # Extract the logic directly from the function
        model = request_body.get("model", "")
        if not model and hasattr(logging_obj, "model_call_details") and logging_obj.model_call_details.get("model"):
            model = logging_obj.model_call_details.get("model")

        assert model == "claude-3-sonnet-20240229"  # Should use request_body model

        # Test case 2: Fallback to logging obj
        request_body = {}
        logging_obj = self._create_mock_logging_obj(model_in_details="claude-3-haiku-20240307")

        model = request_body.get("model", "")
        if not model and hasattr(logging_obj, "model_call_details") and logging_obj.model_call_details.get("model"):
            model = logging_obj.model_call_details.get("model")

        assert model == "claude-3-haiku-20240307"  # Should use fallback model

        # Test case 3: Empty string in request body, fallback to logging obj
        request_body = {"model": ""}
        logging_obj = self._create_mock_logging_obj(model_in_details="claude-3-opus-20240229")

        model = request_body.get("model", "")
        if not model and hasattr(logging_obj, "model_call_details") and logging_obj.model_call_details.get("model"):
            model = logging_obj.model_call_details.get("model")

        assert model == "claude-3-opus-20240229"  # Should use fallback model

        # Test case 4: Both empty
        request_body = {}
        logging_obj = self._create_mock_logging_obj()

        model = request_body.get("model", "")
        if not model and hasattr(logging_obj, "model_call_details") and logging_obj.model_call_details.get("model"):
            model = logging_obj.model_call_details.get("model")

        assert model == ""  # Should be empty

    def test_edge_case_missing_model_call_details_attribute(self):
        """Test fallback behavior when logging_obj doesn't have a model_call_details attribute"""
        # Case where logging_obj doesn't have the attribute at all
        request_body = {"model": ""}  # Empty model in request body
        logging_obj = MagicMock()
        # Remove the attribute to simulate it not existing
        if hasattr(logging_obj, "model_call_details"):
            delattr(logging_obj, "model_call_details")

        # Extract the logic directly from the function
        model = request_body.get("model", "")
        if not model and hasattr(logging_obj, "model_call_details") and logging_obj.model_call_details.get("model"):
            model = logging_obj.model_call_details.get("model")

        assert model == ""  # Should remain empty since no fallback is available

        # Case where model_call_details exists but get returns None
        request_body = {"model": ""}
        logging_obj = self._create_mock_logging_obj()  # Empty dict

        model = request_body.get("model", "")
        if not model and hasattr(logging_obj, "model_call_details") and logging_obj.model_call_details.get("model"):
            model = logging_obj.model_call_details.get("model")

        assert model == ""  # Should remain empty
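

# A minimal standalone sketch (a hypothetical helper, not part of the handler's
# API) of the exact resolution order the tests above exercise inline:
def _resolve_model_sketch(request_body: Dict[str, Any], logging_obj: Any) -> str:
    """Prefer the request body's model, then fall back to model_call_details."""
    model = request_body.get("model", "")
    if not model and hasattr(logging_obj, "model_call_details") and logging_obj.model_call_details.get("model"):
        model = logging_obj.model_call_details.get("model")
    return model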
@@ -0,0 +1,855 @@
import json
import os
import sys
import traceback
from unittest import mock
from unittest.mock import AsyncMock, MagicMock, Mock, patch

import httpx
import pytest
from fastapi import Request, Response
from fastapi.testclient import TestClient

sys.path.insert(
    0, os.path.abspath("../../../..")
)  # Adds the parent directory to the system path

import litellm
from litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints import (
    BaseOpenAIPassThroughHandler,
    RouteChecks,
    create_pass_through_route,
    vertex_discovery_proxy_route,
    vertex_proxy_route,
)
from litellm.types.passthrough_endpoints.vertex_ai import VertexPassThroughCredentials


class TestBaseOpenAIPassThroughHandler:
    def test_join_url_paths(self):
        print("\nTesting _join_url_paths method...")

        # Test joining a base URL with no path and a path
        base_url = httpx.URL("https://api.example.com")
        path = "/v1/chat/completions"
        result = BaseOpenAIPassThroughHandler._join_url_paths(
            base_url, path, litellm.LlmProviders.OPENAI.value
        )
        print(f"Base URL with no path: '{base_url}' + '{path}' → '{result}'")
        assert str(result) == "https://api.example.com/v1/chat/completions"

        # Test joining a base URL that has a path with another path
        base_url = httpx.URL("https://api.example.com/v1")
        path = "/chat/completions"
        result = BaseOpenAIPassThroughHandler._join_url_paths(
            base_url, path, litellm.LlmProviders.OPENAI.value
        )
        print(f"Base URL with path: '{base_url}' + '{path}' → '{result}'")
        assert str(result) == "https://api.example.com/v1/chat/completions"

        # Test with a path not starting with a slash
        base_url = httpx.URL("https://api.example.com/v1")
        path = "chat/completions"
        result = BaseOpenAIPassThroughHandler._join_url_paths(
            base_url, path, litellm.LlmProviders.OPENAI.value
        )
        print(f"Path without leading slash: '{base_url}' + '{path}' → '{result}'")
        assert str(result) == "https://api.example.com/v1/chat/completions"

        # Test with a base URL having a trailing slash
        base_url = httpx.URL("https://api.example.com/v1/")
        path = "/chat/completions"
        result = BaseOpenAIPassThroughHandler._join_url_paths(
            base_url, path, litellm.LlmProviders.OPENAI.value
        )
        print(f"Base URL with trailing slash: '{base_url}' + '{path}' → '{result}'")
        assert str(result) == "https://api.example.com/v1/chat/completions"
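
        # Taken together, the four cases above pin down simple join semantics:
        # exactly one slash at the boundary, regardless of trailing/leading
        # slashes. A sketch of that rule (ignoring the custom_llm_provider
        # argument, whose effect these cases don't distinguish):
        #     joined = str(base_url).rstrip("/") + "/" + path.lstrip("/")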

    def test_append_openai_beta_header(self):
        print("\nTesting _append_openai_beta_header method...")

        # Create mock requests with different paths
        assistants_request = MagicMock(spec=Request)
        assistants_request.url = MagicMock()
        assistants_request.url.path = "/v1/threads/thread_123456/messages"

        non_assistants_request = MagicMock(spec=Request)
        non_assistants_request.url = MagicMock()
        non_assistants_request.url.path = "/v1/chat/completions"

        headers = {"authorization": "Bearer test_key"}

        # Test with an Assistants API request (threads path)
        result = BaseOpenAIPassThroughHandler._append_openai_beta_header(
            headers, assistants_request
        )
        print(f"Assistants API request: Added header: {result}")
        assert result["OpenAI-Beta"] == "assistants=v2"

        # Test with a non-Assistants API request
        headers = {"authorization": "Bearer test_key"}
        result = BaseOpenAIPassThroughHandler._append_openai_beta_header(
            headers, non_assistants_request
        )
        print(f"Non-assistants API request: Headers: {result}")
        assert "OpenAI-Beta" not in result

        # Test with "assistants" in the path
        assistant_request = MagicMock(spec=Request)
        assistant_request.url = MagicMock()
        assistant_request.url.path = "/v1/assistants/asst_123456"

        headers = {"authorization": "Bearer test_key"}
        result = BaseOpenAIPassThroughHandler._append_openai_beta_header(
            headers, assistant_request
        )
        print(f"Assistant API request: Added header: {result}")
        assert result["OpenAI-Beta"] == "assistants=v2"
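
        # Rule exercised above: requests whose path targets the Assistants API
        # (e.g. /v1/threads/... or /v1/assistants/...) get the
        # "OpenAI-Beta: assistants=v2" header appended; other paths are left
        # unchanged.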

    def test_assemble_headers(self):
        print("\nTesting _assemble_headers method...")

        # Mock request
        mock_request = MagicMock(spec=Request)
        api_key = "test_api_key"

        # Patch _append_openai_beta_header to avoid testing it again here
        with patch.object(
            BaseOpenAIPassThroughHandler,
            "_append_openai_beta_header",
            return_value={
                "authorization": "Bearer test_api_key",
                "api-key": "test_api_key",
                "test-header": "value",
            },
        ):
            result = BaseOpenAIPassThroughHandler._assemble_headers(
                api_key, mock_request
            )
            print(f"Assembled headers: {result}")
            assert result["authorization"] == "Bearer test_api_key"
            assert result["api-key"] == "test_api_key"
            assert result["test-header"] == "value"

    @pytest.mark.asyncio
    @patch(
        "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.create_pass_through_route"
    )
    async def test_base_openai_pass_through_handler(self, mock_create_pass_through):
        print("\nTesting _base_openai_pass_through_handler method...")

        # Mock dependencies
        mock_request = MagicMock(spec=Request)
        mock_request.query_params = {"model": "gpt-4"}
        mock_response = MagicMock(spec=Response)
        mock_user_api_key_dict = MagicMock()

        # Mock the endpoint function returned by create_pass_through_route
        mock_endpoint_func = MagicMock()
        mock_endpoint_func.return_value = {"result": "success"}
        mock_create_pass_through.return_value = mock_endpoint_func

        print("Testing standard endpoint pass-through...")
        # Test with a standard endpoint
        result = await BaseOpenAIPassThroughHandler._base_openai_pass_through_handler(
            endpoint="/chat/completions",
            request=mock_request,
            fastapi_response=mock_response,
            user_api_key_dict=mock_user_api_key_dict,
            base_target_url="https://api.openai.com",
            api_key="test_api_key",
            custom_llm_provider=litellm.LlmProviders.OPENAI.value,
        )

        # Verify the result
        print(f"Result from handler: {result}")
        assert result == {"result": "success"}

        # Verify create_pass_through_route was called with correct parameters
        call_args = mock_create_pass_through.call_args[1]
        print(
            f"create_pass_through_route called with endpoint: {call_args['endpoint']}"
        )
        print(f"create_pass_through_route called with target: {call_args['target']}")
        assert call_args["endpoint"] == "/chat/completions"
        assert call_args["target"] == "https://api.openai.com/v1/chat/completions"

        # Verify endpoint_func was called with correct parameters
        print("Verifying endpoint_func call parameters...")
        call_kwargs = mock_endpoint_func.call_args[1]
        print(f"stream parameter: {call_kwargs['stream']}")
        print(f"query_params: {call_kwargs['query_params']}")
        assert call_kwargs["stream"] is False
        assert call_kwargs["query_params"] == {"model": "gpt-4"}


class TestVertexAIPassThroughHandler:
    """
    Case 1: User set passthrough credentials - confirm credentials used.

    Case 2: User set default credentials, no exact passthrough credentials - confirm default credentials used.

    Case 3: No default credentials, no mapped credentials - request passed through directly.
    """
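
    # Target-URL construction pinned down by the assertions below (a sketch,
    # inferred only from the expected targets in these tests):
    #     host = "https://aiplatform.googleapis.com" if location == "global" \
    #         else f"https://{location}-aiplatform.googleapis.com"
    #     target = f"{host}/v1/projects/{project}/locations/{location}/publishers/..."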

    @pytest.mark.asyncio
    async def test_vertex_passthrough_with_credentials(self, monkeypatch):
        """
        Test that when passthrough credentials are set, they are correctly used in the request
        """
        from litellm.proxy.pass_through_endpoints.passthrough_endpoint_router import (
            PassthroughEndpointRouter,
        )

        vertex_project = "test-project"
        vertex_location = "us-central1"
        vertex_credentials = "test-creds"

        pass_through_router = PassthroughEndpointRouter()

        pass_through_router.add_vertex_credentials(
            project_id=vertex_project,
            location=vertex_location,
            vertex_credentials=vertex_credentials,
        )

        monkeypatch.setattr(
            "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.passthrough_endpoint_router",
            pass_through_router,
        )

        endpoint = f"/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/gemini-1.5-flash:generateContent"

        # Mock request
        mock_request = Mock()
        mock_request.method = "POST"
        mock_request.headers = {
            "Authorization": "Bearer test-creds",
            "Content-Type": "application/json",
        }
        mock_request.url = Mock()
        mock_request.url.path = endpoint

        # Mock response
        mock_response = Response()

        # Mock vertex credentials
        test_project = vertex_project
        test_location = vertex_location
        test_token = vertex_credentials

        with mock.patch(
            "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.vertex_llm_base._ensure_access_token_async"
        ) as mock_ensure_token, mock.patch(
            "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.vertex_llm_base._get_token_and_url"
        ) as mock_get_token, mock.patch(
            "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.create_pass_through_route"
        ) as mock_create_route, mock.patch(
            "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.get_litellm_virtual_key"
        ) as mock_get_virtual_key, mock.patch(
            "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.user_api_key_auth"
        ) as mock_user_auth:
            # Set up mocks
            mock_ensure_token.return_value = ("test-auth-header", test_project)
            mock_get_token.return_value = (test_token, "")
            mock_get_virtual_key.return_value = "Bearer test-key"
            mock_user_auth.return_value = {"api_key": "test-key"}

            # Mock create_pass_through_route to return a function that returns a mock response
            mock_endpoint_func = AsyncMock(return_value={"status": "success"})
            mock_create_route.return_value = mock_endpoint_func

            # Call the route
            try:
                result = await vertex_proxy_route(
                    endpoint=endpoint,
                    request=mock_request,
                    fastapi_response=mock_response,
                    user_api_key_dict={"api_key": "test-key"},
                )
            except Exception as e:
                print(f"Error: {e}")

            # Verify create_pass_through_route was called with correct arguments
            mock_create_route.assert_called_once_with(
                endpoint=endpoint,
                target=f"https://{test_location}-aiplatform.googleapis.com/v1/projects/{test_project}/locations/{test_location}/publishers/google/models/gemini-1.5-flash:generateContent",
                custom_headers={"Authorization": f"Bearer {test_token}"},
            )

    @pytest.mark.asyncio
    async def test_vertex_passthrough_with_global_location(self, monkeypatch):
        """
        Test that when the global location is used, it is correctly handled in the request
        """
        from litellm.proxy.pass_through_endpoints.passthrough_endpoint_router import (
            PassthroughEndpointRouter,
        )

        vertex_project = "test-project"
        vertex_location = "global"
        vertex_credentials = "test-creds"

        pass_through_router = PassthroughEndpointRouter()

        pass_through_router.add_vertex_credentials(
            project_id=vertex_project,
            location=vertex_location,
            vertex_credentials=vertex_credentials,
        )

        monkeypatch.setattr(
            "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.passthrough_endpoint_router",
            pass_through_router,
        )

        endpoint = f"/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/gemini-1.5-flash:generateContent"

        # Mock request
        mock_request = Mock()
        mock_request.method = "POST"
        mock_request.headers = {
            "Authorization": "Bearer test-creds",
            "Content-Type": "application/json",
        }
        mock_request.url = Mock()
        mock_request.url.path = endpoint

        # Mock response
        mock_response = Response()

        # Mock vertex credentials
        test_project = vertex_project
        test_location = vertex_location
        test_token = vertex_credentials

        with mock.patch(
            "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.vertex_llm_base._ensure_access_token_async"
        ) as mock_ensure_token, mock.patch(
            "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.vertex_llm_base._get_token_and_url"
        ) as mock_get_token, mock.patch(
            "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.create_pass_through_route"
        ) as mock_create_route, mock.patch(
            "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.get_litellm_virtual_key"
        ) as mock_get_virtual_key, mock.patch(
            "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.user_api_key_auth"
        ) as mock_user_auth:
            # Set up mocks
            mock_ensure_token.return_value = ("test-auth-header", test_project)
            mock_get_token.return_value = (test_token, "")
            mock_get_virtual_key.return_value = "Bearer test-key"
            mock_user_auth.return_value = {"api_key": "test-key"}

            # Mock create_pass_through_route to return a function that returns a mock response
            mock_endpoint_func = AsyncMock(return_value={"status": "success"})
            mock_create_route.return_value = mock_endpoint_func

            # Call the route
            try:
                result = await vertex_proxy_route(
                    endpoint=endpoint,
                    request=mock_request,
                    fastapi_response=mock_response,
                    user_api_key_dict={"api_key": "test-key"},
                )
            except Exception as e:
                print(f"Error: {e}")

            # Verify create_pass_through_route was called with correct arguments
            # (note: no region prefix on the host for the global location)
            mock_create_route.assert_called_once_with(
                endpoint=endpoint,
                target=f"https://aiplatform.googleapis.com/v1/projects/{test_project}/locations/{test_location}/publishers/google/models/gemini-1.5-flash:generateContent",
                custom_headers={"Authorization": f"Bearer {test_token}"},
            )

    @pytest.mark.parametrize(
        "initial_endpoint",
        [
            "publishers/google/models/gemini-1.5-flash:generateContent",
            "v1/projects/bad-project/locations/bad-location/publishers/google/models/gemini-1.5-flash:generateContent",
        ],
    )
    @pytest.mark.asyncio
    async def test_vertex_passthrough_with_default_credentials(
        self, monkeypatch, initial_endpoint
    ):
        """
        Test that when no passthrough credentials are set, default credentials are used in the request
        """
        from litellm.proxy.pass_through_endpoints.passthrough_endpoint_router import (
            PassthroughEndpointRouter,
        )

        # Set up default credentials
        default_project = "default-project"
        default_location = "us-central1"
        default_credentials = "default-creds"

        pass_through_router = PassthroughEndpointRouter()
        pass_through_router.default_vertex_config = VertexPassThroughCredentials(
            vertex_project=default_project,
            vertex_location=default_location,
            vertex_credentials=default_credentials,
        )

        monkeypatch.setattr(
            "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.passthrough_endpoint_router",
            pass_through_router,
        )

        # Use a different project/location in the request than the default
        endpoint = initial_endpoint

        mock_request = Request(
            scope={
                "type": "http",
                "method": "POST",
                "path": f"/vertex_ai/{endpoint}",
                "headers": {},
            }
        )
        mock_response = Response()

        with mock.patch(
            "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.vertex_llm_base._ensure_access_token_async"
        ) as mock_ensure_token, mock.patch(
            "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.vertex_llm_base._get_token_and_url"
        ) as mock_get_token, mock.patch(
            "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.create_pass_through_route"
        ) as mock_create_route:
            mock_ensure_token.return_value = ("test-auth-header", default_project)
            mock_get_token.return_value = (default_credentials, "")

            try:
                await vertex_proxy_route(
                    endpoint=endpoint,
                    request=mock_request,
                    fastapi_response=mock_response,
                )
            except Exception as e:
                traceback.print_exc()
                print(f"Error: {e}")

            # Verify default credentials were used
            mock_create_route.assert_called_once_with(
                endpoint=endpoint,
                target=f"https://{default_location}-aiplatform.googleapis.com/v1/projects/{default_project}/locations/{default_location}/publishers/google/models/gemini-1.5-flash:generateContent",
                custom_headers={"Authorization": f"Bearer {default_credentials}"},
            )

    @pytest.mark.asyncio
    async def test_vertex_passthrough_with_no_default_credentials(self, monkeypatch):
        """
        Test that when neither matching passthrough credentials nor default
        credentials are set, the request's own credentials are passed through directly
        """
        from litellm.proxy.pass_through_endpoints.passthrough_endpoint_router import (
            PassthroughEndpointRouter,
        )

        vertex_project = "my-project"
        vertex_location = "us-central1"
        vertex_credentials = "test-creds"

        test_project = "test-project"
        test_location = "test-location"
        test_token = "test-creds"

        pass_through_router = PassthroughEndpointRouter()

        pass_through_router.add_vertex_credentials(
            project_id=vertex_project,
            location=vertex_location,
            vertex_credentials=vertex_credentials,
        )

        monkeypatch.setattr(
            "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.passthrough_endpoint_router",
            pass_through_router,
        )

        endpoint = f"/v1/projects/{test_project}/locations/{test_location}/publishers/google/models/gemini-1.5-flash:generateContent"

        # Mock request
        mock_request = Request(
            scope={
                "type": "http",
                "method": "POST",
                "path": endpoint,
                "headers": [
                    (b"authorization", b"Bearer test-creds"),
                ],
            }
        )

        # Mock response
        mock_response = Response()

        with mock.patch(
            "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.vertex_llm_base._ensure_access_token_async"
        ) as mock_ensure_token, mock.patch(
            "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.vertex_llm_base._get_token_and_url"
        ) as mock_get_token, mock.patch(
            "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.create_pass_through_route"
        ) as mock_create_route:
            mock_ensure_token.return_value = ("test-auth-header", test_project)
            mock_get_token.return_value = (test_token, "")

            # Call the route
            try:
                await vertex_proxy_route(
                    endpoint=endpoint,
                    request=mock_request,
                    fastapi_response=mock_response,
                )
            except Exception as e:
                traceback.print_exc()
                print(f"Error: {e}")

            # Verify create_pass_through_route was called with correct arguments
            # (the caller's own authorization header is forwarded)
            mock_create_route.assert_called_once_with(
                endpoint=endpoint,
                target=f"https://{test_location}-aiplatform.googleapis.com/v1/projects/{test_project}/locations/{test_location}/publishers/google/models/gemini-1.5-flash:generateContent",
                custom_headers={"authorization": f"Bearer {test_token}"},
            )

    @pytest.mark.asyncio
    async def test_async_vertex_proxy_route_api_key_auth(self):
        """
        Critical: this is how the Vertex AI JS SDK authenticates to the LiteLLM proxy.
        """
        # Mock dependencies
        mock_request = Mock()
        mock_request.headers = {"x-litellm-api-key": "test-key-123"}
        mock_request.method = "POST"
        mock_response = Mock()

        with patch(
            "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.user_api_key_auth"
        ) as mock_auth:
            mock_auth.return_value = {"api_key": "test-key-123"}

            with patch(
                "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.create_pass_through_route"
            ) as mock_pass_through:
                mock_pass_through.return_value = AsyncMock(
                    return_value={"status": "success"}
                )

                # Call the function
                result = await vertex_proxy_route(
                    endpoint="v1/projects/test-project/locations/us-central1/publishers/google/models/gemini-1.5-pro:generateContent",
                    request=mock_request,
                    fastapi_response=mock_response,
                )

        # Verify user_api_key_auth was called with the correct Bearer token
        mock_auth.assert_called_once()
        call_args = mock_auth.call_args[1]
        assert call_args["api_key"] == "Bearer test-key-123"

    def test_vertex_passthrough_handler_multimodal_embedding_response(self):
        """
        Test that vertex_passthrough_handler correctly identifies and processes multimodal embedding responses
        """
        import datetime
        from unittest.mock import Mock

        from litellm.litellm_core_utils.litellm_logging import (
            Logging as LiteLLMLoggingObj,
        )
        from litellm.proxy.pass_through_endpoints.llm_provider_handlers.vertex_passthrough_logging_handler import (
            VertexPassthroughLoggingHandler,
        )

        # Create mock multimodal embedding response data
        multimodal_response_data = {
            "predictions": [
                {
                    "textEmbedding": [0.1, 0.2, 0.3, 0.4, 0.5],
                    "imageEmbedding": [0.6, 0.7, 0.8, 0.9, 1.0],
                },
                {
                    "videoEmbeddings": [
                        {
                            "embedding": [0.11, 0.22, 0.33, 0.44, 0.55],
                            "startOffsetSec": 0,
                            "endOffsetSec": 5,
                        }
                    ]
                },
            ]
        }

        # Create mock httpx.Response
        mock_httpx_response = Mock()
        mock_httpx_response.json.return_value = multimodal_response_data
        mock_httpx_response.status_code = 200

        # Create mock logging object
        mock_logging_obj = Mock(spec=LiteLLMLoggingObj)
        mock_logging_obj.litellm_call_id = "test-call-id-123"
        mock_logging_obj.model_call_details = {}

        # Test URL with a multimodal embedding model
        url_route = "/v1/projects/test-project/locations/us-central1/publishers/google/models/multimodalembedding@001:predict"

        start_time = datetime.datetime.now()
        end_time = datetime.datetime.now()

        with patch(
            "litellm.llms.vertex_ai.multimodal_embeddings.transformation.VertexAIMultimodalEmbeddingConfig"
        ) as mock_multimodal_config:
            # Mock the multimodal config instance and its methods
            mock_config_instance = Mock()
            mock_multimodal_config.return_value = mock_config_instance

            # Create a mock embedding response that would be returned by the transformation
            from litellm.types.utils import Embedding, EmbeddingResponse, Usage

            mock_embedding_response = EmbeddingResponse(
                object="list",
                data=[
                    Embedding(embedding=[0.1, 0.2, 0.3, 0.4, 0.5], index=0, object="embedding"),
                    Embedding(embedding=[0.6, 0.7, 0.8, 0.9, 1.0], index=1, object="embedding"),
                ],
                model="multimodalembedding@001",
                usage=Usage(prompt_tokens=0, total_tokens=0, completion_tokens=0),
            )
            mock_config_instance.transform_embedding_response.return_value = mock_embedding_response

            # Call the handler
            result = VertexPassthroughLoggingHandler.vertex_passthrough_handler(
                httpx_response=mock_httpx_response,
                logging_obj=mock_logging_obj,
                url_route=url_route,
                result="test-result",
                start_time=start_time,
                end_time=end_time,
                cache_hit=False,
            )

            # Verify multimodal embedding detection and processing
            assert result is not None
            assert "result" in result
            assert "kwargs" in result

            # Verify that the multimodal config was instantiated and used
            mock_multimodal_config.assert_called_once()
            mock_config_instance.transform_embedding_response.assert_called_once()

            # Verify the response is an EmbeddingResponse
            assert isinstance(result["result"], EmbeddingResponse)
            assert result["result"].model == "multimodalembedding@001"
            assert len(result["result"].data) == 2

    def test_vertex_passthrough_handler_multimodal_detection_method(self):
        """
        Test the _is_multimodal_embedding_response detection method specifically
        """
        from litellm.proxy.pass_through_endpoints.llm_provider_handlers.vertex_passthrough_logging_handler import (
            VertexPassthroughLoggingHandler,
        )

        # Test case 1: Response with textEmbedding should be detected as multimodal
        response_with_text_embedding = {
            "predictions": [
                {"textEmbedding": [0.1, 0.2, 0.3]}
            ]
        }
        assert VertexPassthroughLoggingHandler._is_multimodal_embedding_response(response_with_text_embedding) is True

        # Test case 2: Response with imageEmbedding should be detected as multimodal
        response_with_image_embedding = {
            "predictions": [
                {"imageEmbedding": [0.4, 0.5, 0.6]}
            ]
        }
        assert VertexPassthroughLoggingHandler._is_multimodal_embedding_response(response_with_image_embedding) is True

        # Test case 3: Response with videoEmbeddings should be detected as multimodal
        response_with_video_embeddings = {
            "predictions": [
                {
                    "videoEmbeddings": [
                        {
                            "embedding": [0.7, 0.8, 0.9],
                            "startOffsetSec": 0,
                            "endOffsetSec": 5,
                        }
                    ]
                }
            ]
        }
        assert VertexPassthroughLoggingHandler._is_multimodal_embedding_response(response_with_video_embeddings) is True

        # Test case 4: Regular text embedding response should NOT be detected as multimodal
        regular_embedding_response = {
            "predictions": [
                {"embeddings": {"values": [0.1, 0.2, 0.3]}}
            ]
        }
        assert VertexPassthroughLoggingHandler._is_multimodal_embedding_response(regular_embedding_response) is False

        # Test case 5: Non-embedding response should NOT be detected as multimodal
        non_embedding_response = {
            "candidates": [
                {"content": {"parts": [{"text": "Hello world"}]}}
            ]
        }
        assert VertexPassthroughLoggingHandler._is_multimodal_embedding_response(non_embedding_response) is False

        # Test case 6: Empty response should NOT be detected as multimodal
        empty_response = {}
        assert VertexPassthroughLoggingHandler._is_multimodal_embedding_response(empty_response) is False
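
    # Sketch of the rule the six cases above pin down (a hypothetical
    # reimplementation, not the handler's actual code):
    #     MULTIMODAL_KEYS = ("textEmbedding", "imageEmbedding", "videoEmbeddings")
    #     def _is_multimodal(response: dict) -> bool:
    #         return any(
    #             key in prediction
    #             for prediction in response.get("predictions", [])
    #             for key in MULTIMODAL_KEYS
    #         )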


class TestVertexAIDiscoveryPassThroughHandler:
    """
    Test cases for the Vertex AI Discovery passthrough endpoint
    """
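
    # Unlike the regional {location}-aiplatform.googleapis.com hosts asserted
    # above, the Discovery passthrough targets the fixed host
    # discoveryengine.googleapis.com (see the assertion below).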

    @pytest.mark.asyncio
    async def test_vertex_discovery_passthrough_with_credentials(self, monkeypatch):
        """
        Test that when passthrough credentials are set, they are correctly used in the request
        """
        from litellm.proxy.pass_through_endpoints.passthrough_endpoint_router import (
            PassthroughEndpointRouter,
        )

        vertex_project = "test-project"
        vertex_location = "us-central1"
        vertex_credentials = "test-creds"

        pass_through_router = PassthroughEndpointRouter()

        pass_through_router.add_vertex_credentials(
            project_id=vertex_project,
            location=vertex_location,
            vertex_credentials=vertex_credentials,
        )

        monkeypatch.setattr(
            "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.passthrough_endpoint_router",
            pass_through_router,
        )

        endpoint = f"/v1/projects/{vertex_project}/locations/{vertex_location}/dataStores/default/servingConfigs/default:search"

        # Mock request
        mock_request = Mock()
        mock_request.method = "POST"
        mock_request.headers = {
            "Authorization": "Bearer test-creds",
            "Content-Type": "application/json",
        }
        mock_request.url = Mock()
        mock_request.url.path = endpoint

        # Mock response
        mock_response = Response()

        # Mock vertex credentials
        test_project = vertex_project
        test_location = vertex_location
        test_token = vertex_credentials

        with mock.patch(
            "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.vertex_llm_base._ensure_access_token_async"
        ) as mock_ensure_token, mock.patch(
            "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.vertex_llm_base._get_token_and_url"
        ) as mock_get_token, mock.patch(
            "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.create_pass_through_route"
        ) as mock_create_route, mock.patch(
            "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.get_litellm_virtual_key"
        ) as mock_get_virtual_key, mock.patch(
            "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.user_api_key_auth"
        ) as mock_user_auth:
            # Set up mocks
            mock_ensure_token.return_value = ("test-auth-header", test_project)
            mock_get_token.return_value = (test_token, "")
            mock_get_virtual_key.return_value = "Bearer test-key"
            mock_user_auth.return_value = {"api_key": "test-key"}

            # Mock create_pass_through_route to return a function that returns a mock response
            mock_endpoint_func = AsyncMock(return_value={"status": "success"})
            mock_create_route.return_value = mock_endpoint_func

            # Call the route
            try:
                result = await vertex_discovery_proxy_route(
                    endpoint=endpoint,
                    request=mock_request,
                    fastapi_response=mock_response,
                )
            except Exception as e:
                print(f"Error: {e}")

            # Verify create_pass_through_route was called with correct arguments
            mock_create_route.assert_called_once_with(
                endpoint=endpoint,
                target=f"https://discoveryengine.googleapis.com/v1/projects/{test_project}/locations/{test_location}/dataStores/default/servingConfigs/default:search",
                custom_headers={"Authorization": f"Bearer {test_token}"},
            )

    @pytest.mark.asyncio
    async def test_vertex_discovery_proxy_route_api_key_auth(self):
        """
        Test that the route correctly handles API key authentication
        """
        # Mock dependencies
        mock_request = Mock()
        mock_request.headers = {"x-litellm-api-key": "test-key-123"}
        mock_request.method = "POST"
        mock_response = Mock()

        with patch(
            "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.user_api_key_auth"
        ) as mock_auth:
            mock_auth.return_value = {"api_key": "test-key-123"}

            with patch(
                "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.create_pass_through_route"
            ) as mock_pass_through:
                mock_pass_through.return_value = AsyncMock(
                    return_value={"status": "success"}
                )

                # Call the function
                result = await vertex_discovery_proxy_route(
                    endpoint="v1/projects/test-project/locations/us-central1/dataStores/default/servingConfigs/default:search",
                    request=mock_request,
                    fastapi_response=mock_response,
                )

        # Verify user_api_key_auth was called with the correct Bearer token
        mock_auth.assert_called_once()
        call_args = mock_auth.call_args[1]
        assert call_args["api_key"] == "Bearer test-key-123"


@pytest.mark.asyncio
async def test_is_streaming_request_fn():
    from litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints import (
        is_streaming_request_fn,
    )

    mock_request = Mock()
    mock_request.method = "POST"
    mock_request.headers = {"content-type": "multipart/form-data"}
    mock_request.form = AsyncMock(return_value={"stream": "true"})
    assert await is_streaming_request_fn(mock_request) is True
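
    # Behavior exercised here: for multipart/form-data POSTs the helper reads
    # the parsed form and treats stream="true" as a streaming request. Other
    # content types are not covered by this test, so how they are detected is
    # an open assumption.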
File diff suppressed because it is too large
@@ -0,0 +1,44 @@
import json
import os
import sys
import traceback
from unittest import mock
from unittest.mock import MagicMock, patch

import httpx
import pytest
from fastapi import Request, Response
from fastapi.testclient import TestClient

sys.path.insert(
    0, os.path.abspath("../../../..")
)  # Adds the parent directory to the system path

from unittest.mock import Mock

from litellm.proxy.pass_through_endpoints.common_utils import get_litellm_virtual_key


@pytest.mark.asyncio
async def test_get_litellm_virtual_key():
    """
    Test that get_litellm_virtual_key correctly resolves the virtual API key from request headers
    """
    # Test with x-litellm-api-key
    mock_request = Mock()
    mock_request.headers = {"x-litellm-api-key": "test-key-123"}
    result = get_litellm_virtual_key(mock_request)
    assert result == "Bearer test-key-123"

    # Test with Authorization header
    mock_request.headers = {"Authorization": "Bearer auth-key-456"}
    result = get_litellm_virtual_key(mock_request)
    assert result == "Bearer auth-key-456"

    # Test with both headers (x-litellm-api-key should take precedence)
    mock_request.headers = {
        "x-litellm-api-key": "test-key-123",
        "Authorization": "Bearer auth-key-456",
    }
    result = get_litellm_virtual_key(mock_request)
    assert result == "Bearer test-key-123"
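

# A minimal sketch (a hypothetical reimplementation, inferred only from the
# assertions above) of the precedence get_litellm_virtual_key applies:
def _virtual_key_sketch(headers: dict) -> str:
    # x-litellm-api-key wins over the standard Authorization header
    litellm_key = headers.get("x-litellm-api-key")
    if litellm_key:
        return f"Bearer {litellm_key}"
    return headers.get("Authorization", "")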