# tests/test_litellm/test_router.py
import copy
import json
import os
import sys
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastapi.testclient import TestClient
sys.path.insert(
0, os.path.abspath("../../..")
)  # Add the project directory to the system path so the local litellm package is used
import litellm
from litellm.router_utils.fallback_event_handlers import run_async_fallback
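# Unit tests for litellm.Router: default-param merging, file-operation fan-out,
# fallbacks (including mid-stream), team-scoped deployments, and model-info merging.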
def test_update_kwargs_does_not_mutate_defaults_and_merges_metadata():
    # initialize a real Router (env vars may be unset)
router = litellm.Router(
model_list=[
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {
"model": "azure/chatgpt-v-3",
"api_key": os.getenv("AZURE_API_KEY"),
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE"),
},
}
],
)
# override to known defaults for the test
router.default_litellm_params = {
"foo": "bar",
"metadata": {"baz": 123},
}
original = copy.deepcopy(router.default_litellm_params)
kwargs: dict = {}
# invoke the helper
router._update_kwargs_with_default_litellm_params(
kwargs=kwargs,
metadata_variable_name="litellm_metadata",
)
# 1) router.defaults must be unchanged
assert router.default_litellm_params == original
    # 2) non-metadata keys are merged into kwargs
assert kwargs["foo"] == "bar"
    # 3) metadata lands under the requested variable name ("litellm_metadata")
assert kwargs["litellm_metadata"] == {"baz": 123}
def test_router_with_model_info_and_model_group():
"""
Test edge case where user specifies model_group in model_info
"""
router = litellm.Router(
model_list=[
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {
"model": "gpt-3.5-turbo",
},
"model_info": {
"tpm": 1000,
"rpm": 1000,
"model_group": "gpt-3.5-turbo",
},
}
],
)
router._set_model_group_info(
model_group="gpt-3.5-turbo",
user_facing_model_group_name="gpt-3.5-turbo",
)
@pytest.mark.asyncio
async def test_arouter_with_tags_and_fallbacks():
"""
    If the fallback model is missing the required tag, the fallback call should raise an error
"""
from litellm import Router
router = Router(
model_list=[
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {
"model": "gpt-3.5-turbo",
"mock_response": "Hello, world!",
"tags": ["test"],
},
},
{
"model_name": "anthropic-claude-3-5-sonnet",
"litellm_params": {
"model": "claude-3-5-sonnet-latest",
"mock_response": "Hello, world 2!",
},
},
],
fallbacks=[
{"gpt-3.5-turbo": ["anthropic-claude-3-5-sonnet"]},
],
enable_tag_filtering=True,
)
with pytest.raises(Exception):
response = await router.acompletion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hello, world!"}],
mock_testing_fallbacks=True,
metadata={"tags": ["test"]},
)
@pytest.mark.asyncio
async def test_async_router_acreate_file():
"""
    acreate_file should write the file to all deployments of the model group
"""
from unittest.mock import MagicMock, call, patch
router = litellm.Router(
model_list=[
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {"model": "gpt-3.5-turbo"},
},
{"model_name": "gpt-3.5-turbo", "litellm_params": {"model": "gpt-4o-mini"}},
],
)
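    # acreate_file is expected to fan the upload out to every deployment in the
    # model group, so the file exists wherever a later request may be routed
    # (asserted via call_count below).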
with patch("litellm.acreate_file", return_value=MagicMock()) as mock_acreate_file:
mock_acreate_file.return_value = MagicMock()
response = await router.acreate_file(
model="gpt-3.5-turbo",
purpose="test",
file=MagicMock(),
)
        # litellm.acreate_file should be called twice (once per deployment)
assert mock_acreate_file.call_count == 2
@pytest.mark.asyncio
async def test_async_router_acreate_file_with_jsonl():
"""
Test router.acreate_file with both JSONL and non-JSONL files
"""
import json
from io import BytesIO
from unittest.mock import MagicMock, patch
# Create test JSONL content
jsonl_data = [
{
"body": {
"model": "gpt-3.5-turbo-router",
"messages": [{"role": "user", "content": "test"}],
}
},
{
"body": {
"model": "gpt-3.5-turbo-router",
"messages": [{"role": "user", "content": "test2"}],
}
},
]
jsonl_content = "\n".join(json.dumps(item) for item in jsonl_data)
jsonl_file = BytesIO(jsonl_content.encode("utf-8"))
jsonl_file.name = "test.jsonl"
# Create test non-JSONL content
non_jsonl_content = "This is not a JSONL file"
non_jsonl_file = BytesIO(non_jsonl_content.encode("utf-8"))
non_jsonl_file.name = "test.txt"
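    # Expected behavior, asserted below: for JSONL batch files the router rewrites
    # each request body's "model" from the public group name to the deployment's
    # underlying model; non-JSONL files pass through unmodified.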
router = litellm.Router(
model_list=[
{
"model_name": "gpt-3.5-turbo-router",
"litellm_params": {"model": "gpt-3.5-turbo"},
},
{
"model_name": "gpt-3.5-turbo-router",
"litellm_params": {"model": "gpt-4o-mini"},
},
],
)
with patch("litellm.acreate_file", return_value=MagicMock()) as mock_acreate_file:
# Test with JSONL file
response = await router.acreate_file(
model="gpt-3.5-turbo-router",
purpose="batch",
file=jsonl_file,
)
# Verify mock was called twice (once for each deployment)
print(f"mock_acreate_file.call_count: {mock_acreate_file.call_count}")
print(f"mock_acreate_file.call_args_list: {mock_acreate_file.call_args_list}")
assert mock_acreate_file.call_count == 2
# Get the file content passed to the first call
first_call_file = mock_acreate_file.call_args_list[0][1]["file"]
first_call_content = first_call_file.read().decode("utf-8")
# Verify the model name was replaced in the JSONL content
first_line = json.loads(first_call_content.split("\n")[0])
assert first_line["body"]["model"] == "gpt-3.5-turbo"
# Reset mock for next test
mock_acreate_file.reset_mock()
# Test with non-JSONL file
response = await router.acreate_file(
model="gpt-3.5-turbo-router",
purpose="user_data",
file=non_jsonl_file,
)
# Verify mock was called twice
assert mock_acreate_file.call_count == 2
# Get the file content passed to the first call
first_call_file = mock_acreate_file.call_args_list[0][1]["file"]
first_call_content = first_call_file.read().decode("utf-8")
# Verify the non-JSONL content was not modified
assert first_call_content == non_jsonl_content
@pytest.mark.asyncio
async def test_arouter_async_get_healthy_deployments():
"""
    Test that async_get_healthy_deployments returns the healthy deployments for a model group
"""
router = litellm.Router(
model_list=[
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {"model": "gpt-3.5-turbo"},
},
],
)
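    # With a single configured deployment and no cooldowns, the healthy-deployment
    # list should contain exactly that deployment.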
result = await router.async_get_healthy_deployments(
model="gpt-3.5-turbo",
request_kwargs={},
messages=None,
input=None,
specific_deployment=False,
parent_otel_span=None,
)
assert len(result) == 1
assert result[0]["model_name"] == "gpt-3.5-turbo"
assert result[0]["litellm_params"]["model"] == "gpt-3.5-turbo"
@pytest.mark.asyncio
@patch("litellm.amoderation")
async def test_arouter_amoderation_with_credential_name(mock_amoderation):
"""
Test that router.amoderation passes litellm_credential_name to the underlying litellm.amoderation call
"""
mock_amoderation.return_value = AsyncMock()
router = litellm.Router(
model_list=[
{
"model_name": "text-moderation-stable",
"litellm_params": {
"model": "text-moderation-stable",
"litellm_credential_name": "my-custom-auth",
},
},
],
)
await router.amoderation(input="I love everyone!", model="text-moderation-stable")
mock_amoderation.assert_called_once()
call_kwargs = mock_amoderation.call_args[1] # Get the kwargs of the call
print(
"call kwargs for router.amoderation=",
json.dumps(call_kwargs, indent=4, default=str),
)
assert call_kwargs["litellm_credential_name"] == "my-custom-auth"
assert call_kwargs["model"] == "text-moderation-stable"
def test_arouter_test_team_model():
"""
    Test that router.map_team_model resolves a team's public model name to a deployment
"""
router = litellm.Router(
model_list=[
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {"model": "gpt-3.5-turbo"},
"model_info": {
"team_id": "test-team",
"team_public_model_name": "test-model",
},
},
],
)
result = router.map_team_model(team_model_name="test-model", team_id="test-team")
assert result is not None
def test_arouter_ignore_invalid_deployments():
"""
    Test that invalid deployments are ignored when ignore_invalid_deployments=True
"""
from litellm.types.router import Deployment
router = litellm.Router(
model_list=[
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {"model": "my-bad-model"},
},
],
ignore_invalid_deployments=True,
)
assert router.ignore_invalid_deployments is True
assert router.get_model_list() == []
## check upsert deployment
router.upsert_deployment(
Deployment(
model_name="gpt-3.5-turbo",
litellm_params={"model": "my-bad-model"}, # type: ignore
model_info={"tpm": 1000, "rpm": 1000},
)
)
assert router.get_model_list() == []
@pytest.mark.asyncio
async def test_arouter_aretrieve_batch():
"""
    Test that router.aretrieve_batch forwards the deployment's api_key and api_base to litellm.aretrieve_batch
"""
router = litellm.Router(
model_list=[
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {
"model": "gpt-3.5-turbo",
"custom_llm_provider": "azure",
"api_key": "my-custom-key",
"api_base": "my-custom-base",
},
}
],
)
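    # The deployment's static api_key/api_base should be forwarded to the underlying
    # litellm.aretrieve_batch call, even though the call itself is mocked out.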
with patch.object(
litellm, "aretrieve_batch", return_value=AsyncMock()
) as mock_aretrieve_batch:
try:
response = await router.aretrieve_batch(
model="gpt-3.5-turbo",
)
except Exception as e:
print(f"Error: {e}")
mock_aretrieve_batch.assert_called_once()
print(mock_aretrieve_batch.call_args.kwargs)
assert mock_aretrieve_batch.call_args.kwargs["api_key"] == "my-custom-key"
assert mock_aretrieve_batch.call_args.kwargs["api_base"] == "my-custom-base"
@pytest.mark.asyncio
async def test_arouter_aretrieve_file_content():
"""
    Test that router.afile_content forwards the deployment's api_key and api_base to litellm.afile_content
"""
with patch.object(
litellm, "afile_content", return_value=AsyncMock()
) as mock_afile_content:
router = litellm.Router(
model_list=[
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {
"model": "gpt-3.5-turbo",
"custom_llm_provider": "azure",
"api_key": "my-custom-key",
"api_base": "my-custom-base",
},
}
],
)
try:
response = await router.afile_content(
**{
"model": "gpt-3.5-turbo",
"file_id": "my-unique-file-id",
}
) # type: ignore
except Exception as e:
print(f"Error: {e}")
mock_afile_content.assert_called_once()
print(mock_afile_content.call_args.kwargs)
assert mock_afile_content.call_args.kwargs["api_key"] == "my-custom-key"
assert mock_afile_content.call_args.kwargs["api_base"] == "my-custom-base"
@pytest.mark.asyncio
async def test_arouter_filter_team_based_models():
"""
Test that router.filter_team_based_models filters out models that are not in the team
"""
from litellm.types.router import Deployment
router = litellm.Router(
model_list=[
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {"model": "gpt-3.5-turbo"},
"model_info": {
"team_id": "test-team",
},
},
],
)
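    # Deployments tagged with a team_id should only be visible to requests whose
    # metadata carries a matching user_api_key_team_id.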
# WORKS
result = await router.acompletion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hello, world!"}],
metadata={"user_api_key_team_id": "test-team"},
mock_response="Hello, world!",
)
assert result is not None
# FAILS
with pytest.raises(Exception) as e:
result = await router.acompletion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hello, world!"}],
metadata={"user_api_key_team_id": "test-team-2"},
mock_response="Hello, world!",
)
assert "No deployments available" in str(e.value)
## ADD A MODEL THAT IS NOT IN THE TEAM
router.add_deployment(
Deployment(
model_name="gpt-3.5-turbo",
litellm_params={"model": "gpt-3.5-turbo"}, # type: ignore
model_info={"tpm": 1000, "rpm": 1000},
)
)
result = await router.acompletion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hello, world!"}],
metadata={"user_api_key_team_id": "test-team-2"},
mock_response="Hello, world!",
)
assert result is not None
def test_arouter_should_include_deployment():
"""
Test the should_include_deployment method with various scenarios
The method logic:
1. Returns True if: team_id matches AND model_name matches team_public_model_name
2. Returns True if: model_name matches AND deployment has no team_id
3. Otherwise returns False
"""
router = litellm.Router(
model_list=[
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {"model": "gpt-3.5-turbo"},
"model_info": {
"team_id": "test-team",
},
},
],
)
# Test deployment structures
deployment_with_team_and_public_name = {
"model_name": "gpt-3.5-turbo",
"model_info": {
"team_id": "test-team",
"team_public_model_name": "team-gpt-model",
},
}
deployment_with_team_no_public_name = {
"model_name": "gpt-3.5-turbo",
"model_info": {
"team_id": "test-team",
},
}
deployment_without_team = {
"model_name": "gpt-4",
"model_info": {},
}
deployment_different_team = {
"model_name": "claude-3",
"model_info": {
"team_id": "other-team",
"team_public_model_name": "team-claude-model",
},
}
# Test Case 1: Team-specific deployment - team_id and team_public_model_name match
result = router.should_include_deployment(
model_name="team-gpt-model",
model=deployment_with_team_and_public_name,
team_id="test-team",
)
assert (
result is True
), "Should return True when team_id and team_public_model_name match"
# Test Case 2: Team-specific deployment - team_id matches but model_name doesn't match team_public_model_name
result = router.should_include_deployment(
model_name="different-model",
model=deployment_with_team_and_public_name,
team_id="test-team",
)
assert (
result is False
), "Should return False when team_id matches but model_name doesn't match team_public_model_name"
# Test Case 3: Team-specific deployment - team_id doesn't match
result = router.should_include_deployment(
model_name="team-gpt-model",
model=deployment_with_team_and_public_name,
team_id="different-team",
)
assert result is False, "Should return False when team_id doesn't match"
    # Test Case 4: Team-specific deployment with no team_public_model_name - falls back to matching model_name
result = router.should_include_deployment(
model_name="gpt-3.5-turbo",
model=deployment_with_team_no_public_name,
team_id="test-team",
)
assert (
result is True
), "Should return True when team deployment has no team_public_model_name to match"
# Test Case 5: Non-team deployment - model_name matches and no team_id
result = router.should_include_deployment(
model_name="gpt-4", model=deployment_without_team, team_id=None
)
assert (
result is True
), "Should return True when model_name matches and deployment has no team_id"
# Test Case 6: Non-team deployment - model_name matches but team_id provided (should still work)
result = router.should_include_deployment(
model_name="gpt-4", model=deployment_without_team, team_id="any-team"
)
assert (
result is True
), "Should return True when model_name matches non-team deployment, regardless of team_id param"
# Test Case 7: Non-team deployment - model_name doesn't match
result = router.should_include_deployment(
model_name="different-model", model=deployment_without_team, team_id=None
)
assert result is False, "Should return False when model_name doesn't match"
# Test Case 8: Team deployment accessed without matching team_id
result = router.should_include_deployment(
model_name="gpt-3.5-turbo",
model=deployment_with_team_and_public_name,
team_id=None,
)
assert (
result is True
), "Should return True when matching model with exact model_name"
def test_arouter_responses_api_bridge():
"""
    Test that completion calls to an azure/responses/ model are routed through the responses API bridge
"""
from unittest.mock import MagicMock, patch
from litellm.llms.custom_httpx.http_handler import HTTPHandler
router = litellm.Router(
model_list=[
{
"model_name": "[IP-approved] o3-pro",
"litellm_params": {
"model": "azure/responses/o_series/webinterface-o3-pro",
"api_base": "https://webhook.site/fba79dae-220a-4bb7-9a3a-8caa49604e55",
"api_key": "sk-1234567890",
"api_version": "preview",
"stream": True,
},
"model_info": {
"input_cost_per_token": 0.00002,
"output_cost_per_token": 0.00008,
},
}
],
)
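    # "azure/responses/..." models are served via the responses API bridge: chat
    # completion calls should be dispatched to litellm.responses, with the route
    # prefix stripped from the model name sent upstream.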
## CONFIRM BRIDGE IS CALLED
with patch.object(litellm, "responses", return_value=AsyncMock()) as mock_responses:
result = router.completion(
model="[IP-approved] o3-pro",
messages=[{"role": "user", "content": "Hello, world!"}],
)
assert mock_responses.call_count == 1
## CONFIRM MODEL NAME IS STRIPPED
client = HTTPHandler()
with patch.object(client, "post", return_value=MagicMock()) as mock_post:
try:
result = router.completion(
model="[IP-approved] o3-pro",
messages=[{"role": "user", "content": "Hello, world!"}],
client=client,
num_retries=0,
)
except Exception as e:
print(f"Error: {e}")
assert mock_post.call_count == 1
assert (
mock_post.call_args.kwargs["url"]
== "https://webhook.site/fba79dae-220a-4bb7-9a3a-8caa49604e55/openai/v1/responses?api-version=preview"
)
assert mock_post.call_args.kwargs["json"]["model"] == "webinterface-o3-pro"
@pytest.mark.asyncio
async def test_router_v1_messages_fallbacks():
"""
    Test that router.aanthropic_messages falls back to the configured fallback deployment
"""
router = litellm.Router(
model_list=[
{
"model_name": "claude-3-5-sonnet-latest",
"litellm_params": {
"model": "anthropic/claude-3-5-sonnet-latest",
"mock_response": "litellm.InternalServerError",
},
},
{
"model_name": "bedrock-claude",
"litellm_params": {
"model": "anthropic.claude-3-5-sonnet-20240620-v1:0",
"mock_response": "Hello, world I am a fallback!",
},
},
],
fallbacks=[
{"claude-3-5-sonnet-latest": ["bedrock-claude"]},
],
)
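    # The primary deployment mocks an InternalServerError, so the request should
    # transparently fall back to the bedrock-claude deployment.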
result = await router.aanthropic_messages(
model="claude-3-5-sonnet-latest",
messages=[{"role": "user", "content": "Hello, world!"}],
max_tokens=256,
)
assert result is not None
print(result)
assert result["content"][0]["text"] == "Hello, world I am a fallback!"
def test_add_invalid_provider_to_router():
"""
Test that router.add_deployment raises an error if the provider is invalid
"""
from litellm.types.router import Deployment
router = litellm.Router(
model_list=[
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {"model": "gpt-3.5-turbo"},
}
],
)
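    # "vertex_ai_eu" is not a recognized provider, so add_deployment should raise
    # and leave the pattern router unchanged.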
with pytest.raises(Exception) as e:
router.add_deployment(
Deployment(
model_name="vertex_ai/*",
litellm_params={
"model": "vertex_ai/*",
"custom_llm_provider": "vertex_ai_eu",
},
)
)
assert router.pattern_router.patterns == {}
@pytest.mark.asyncio
async def test_router_ageneric_api_call_with_fallbacks_helper():
"""
Test the _ageneric_api_call_with_fallbacks_helper method with various scenarios
"""
from unittest.mock import AsyncMock, MagicMock, patch
router = litellm.Router(
model_list=[
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {
"model": "gpt-3.5-turbo",
"api_key": "test-key",
"api_base": "https://api.openai.com/v1",
},
"model_info": {
"tpm": 1000,
"rpm": 1000,
},
},
],
)
# Test 1: Successful call
async def mock_generic_function(**kwargs):
return {"result": "success", "model": kwargs.get("model")}
with patch.object(router, "async_get_available_deployment") as mock_get_deployment:
mock_get_deployment.return_value = {
"model_name": "gpt-3.5-turbo",
"litellm_params": {
"model": "gpt-3.5-turbo",
"api_key": "test-key",
"api_base": "https://api.openai.com/v1",
},
}
with patch.object(
router, "_update_kwargs_with_deployment"
) as mock_update_kwargs:
with patch.object(
router, "async_routing_strategy_pre_call_checks"
) as mock_pre_call_checks:
with patch.object(
router, "_get_client", return_value=None
) as mock_get_client:
result = await router._ageneric_api_call_with_fallbacks_helper(
model="gpt-3.5-turbo",
original_generic_function=mock_generic_function,
messages=[{"role": "user", "content": "test"}],
)
assert result is not None
assert result["result"] == "success"
mock_get_deployment.assert_called_once()
mock_update_kwargs.assert_called_once()
mock_pre_call_checks.assert_called_once()
# Test 2: Passthrough on no deployment (success case)
async def mock_passthrough_function(**kwargs):
return {"result": "passthrough", "model": kwargs.get("model")}
with patch.object(router, "async_get_available_deployment") as mock_get_deployment:
mock_get_deployment.side_effect = Exception("No deployment available")
result = await router._ageneric_api_call_with_fallbacks_helper(
model="gpt-3.5-turbo",
original_generic_function=mock_passthrough_function,
passthrough_on_no_deployment=True,
messages=[{"role": "user", "content": "test"}],
)
assert result is not None
assert result["result"] == "passthrough"
assert result["model"] == "gpt-3.5-turbo"
# Test 3: No deployment available and passthrough=False (should raise exception)
with patch.object(router, "async_get_available_deployment") as mock_get_deployment:
mock_get_deployment.side_effect = Exception("No deployment available")
with pytest.raises(Exception) as exc_info:
await router._ageneric_api_call_with_fallbacks_helper(
model="gpt-3.5-turbo",
original_generic_function=mock_generic_function,
passthrough_on_no_deployment=False,
messages=[{"role": "user", "content": "test"}],
)
assert "No deployment available" in str(exc_info.value)
# Test 4: Test with semaphore (rate limiting)
import asyncio
async def mock_semaphore_function(**kwargs):
return {"result": "semaphore_success", "model": kwargs.get("model")}
with patch.object(router, "async_get_available_deployment") as mock_get_deployment:
mock_get_deployment.return_value = {
"model_name": "gpt-3.5-turbo",
"litellm_params": {
"model": "gpt-3.5-turbo",
"api_key": "test-key",
"api_base": "https://api.openai.com/v1",
},
}
mock_semaphore = asyncio.Semaphore(1)
with patch.object(
router, "_update_kwargs_with_deployment"
) as mock_update_kwargs:
with patch.object(
router, "_get_client", return_value=mock_semaphore
) as mock_get_client:
with patch.object(
router, "async_routing_strategy_pre_call_checks"
) as mock_pre_call_checks:
result = await router._ageneric_api_call_with_fallbacks_helper(
model="gpt-3.5-turbo",
original_generic_function=mock_semaphore_function,
messages=[{"role": "user", "content": "test"}],
)
assert result is not None
assert result["result"] == "semaphore_success"
mock_get_client.assert_called_once()
mock_pre_call_checks.assert_called_once()
# Test 5: Test call tracking (success and failure counts)
initial_success_count = router.success_calls.get("gpt-3.5-turbo", 0)
initial_fail_count = router.fail_calls.get("gpt-3.5-turbo", 0)
async def mock_failing_function(**kwargs):
raise Exception("Mock failure")
with patch.object(router, "async_get_available_deployment") as mock_get_deployment:
mock_get_deployment.return_value = {
"model_name": "gpt-3.5-turbo",
"litellm_params": {
"model": "gpt-3.5-turbo",
"api_key": "test-key",
"api_base": "https://api.openai.com/v1",
},
}
with patch.object(
router, "_update_kwargs_with_deployment"
) as mock_update_kwargs:
with patch.object(
router, "_get_client", return_value=None
) as mock_get_client:
with patch.object(
router, "async_routing_strategy_pre_call_checks"
) as mock_pre_call_checks:
with pytest.raises(Exception) as exc_info:
await router._ageneric_api_call_with_fallbacks_helper(
model="gpt-3.5-turbo",
original_generic_function=mock_failing_function,
messages=[{"role": "user", "content": "test"}],
)
assert "Mock failure" in str(exc_info.value)
# Check that fail_calls was incremented
assert router.fail_calls["gpt-3.5-turbo"] == initial_fail_count + 1
@pytest.mark.asyncio
async def test_router_forward_client_headers_by_model_group():
"""
Test that router.forward_client_headers_by_model_group returns the correct response
"""
from unittest.mock import MagicMock, patch
from litellm.types.router import ModelGroupSettings
litellm.model_group_settings = ModelGroupSettings(
forward_client_headers_to_llm_api=[
"gpt-3.5-turbo-allow",
"openai/*",
"gpt-3.5-turbo-custom",
]
)
router = litellm.Router(
model_list=[
{
"model_name": "gpt-3.5-turbo-allow",
"litellm_params": {
"model": "gpt-3.5-turbo",
},
},
{
"model_name": "gpt-3.5-turbo-disallow",
"litellm_params": {
"model": "gpt-3.5-turbo",
},
},
{
"model_name": "openai/*",
"litellm_params": {
"model": "openai/*",
},
},
{
"model_name": "openai/gpt-4o-mini",
"litellm_params": {
"model": "openai/gpt-4o-mini",
},
},
],
model_group_alias={
"gpt-3.5-turbo-custom": "gpt-3.5-turbo-disallow",
},
)
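    # Headers should only be forwarded for model groups listed in
    # litellm.model_group_settings.forward_client_headers_to_llm_api; direct names,
    # wildcards, and model_group_alias entries are all resolved against that list.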
## Scenario 1: Direct model name
with patch.object(
litellm.main, "completion", return_value=MagicMock()
) as mock_completion:
await router.acompletion(
model="gpt-3.5-turbo-allow",
messages=[{"role": "user", "content": "Hello, world!"}],
mock_response="Hello, world!",
secret_fields={"raw_headers": {"test": "test"}},
)
mock_completion.assert_called_once()
print(mock_completion.call_args.kwargs["headers"])
## Scenario 2: Wildcard model name
with patch.object(
litellm.main, "completion", return_value=MagicMock()
) as mock_completion:
await router.acompletion(
model="openai/gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hello, world!"}],
mock_response="Hello, world!",
secret_fields={"raw_headers": {"test": "test"}},
)
mock_completion.assert_called_once()
print(mock_completion.call_args.kwargs["headers"])
## Scenario 3: Not in model_group_settings
with patch.object(
litellm.main, "completion", return_value=MagicMock()
) as mock_completion:
await router.acompletion(
model="openai/gpt-4o-mini",
messages=[{"role": "user", "content": "Hello, world!"}],
mock_response="Hello, world!",
secret_fields={"raw_headers": {"test": "test"}},
)
mock_completion.assert_called_once()
assert mock_completion.call_args.kwargs.get("headers") is None
## Scenario 4: Model group alias
with patch.object(
litellm.main, "completion", return_value=MagicMock()
) as mock_completion:
await router.acompletion(
model="gpt-3.5-turbo-custom",
messages=[{"role": "user", "content": "Hello, world!"}],
mock_response="Hello, world!",
secret_fields={"raw_headers": {"test": "test"}},
)
mock_completion.assert_called_once()
print(mock_completion.call_args.kwargs["headers"])
def test_router_apply_default_settings():
"""
Test that Router.apply_default_settings() adds the expected default pre-call checks
"""
router = litellm.Router(
model_list=[
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {"model": "gpt-3.5-turbo"},
}
],
)
# Apply default settings
result = router.apply_default_settings()
# Verify the method returns None
assert result is None
# Verify that the forward_client_headers_by_model_group pre-call check was added
# Check if any callback is of the ForwardClientHeadersByModelGroupCheck type
has_forward_headers_check = False
for callback in litellm.callbacks:
print(callback)
print(f"callback.__class__: {callback.__class__}")
if hasattr(
callback, "__class__"
) and "ForwardClientSideHeadersByModelGroup" in str(callback.__class__):
has_forward_headers_check = True
break
assert (
has_forward_headers_check
), "Expected ForwardClientSideHeadersByModelGroup to be added to callbacks"
def test_router_get_model_access_groups_team_only_models():
"""
Test that Router.get_model_access_groups returns the correct response for team-only models
"""
router = litellm.Router(
model_list=[
{
"model_name": "my-custom-model-name",
"litellm_params": {"model": "gpt-3.5-turbo"},
"model_info": {
"team_id": "team_1",
"access_groups": ["default-models"],
"team_public_model_name": "gpt-3.5-turbo",
},
},
]
)
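    # The access group is attached to a team-only model, so it should be hidden
    # unless the caller's team_id matches.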
access_groups = router.get_model_access_groups(
model_name="gpt-3.5-turbo", team_id=None
)
assert len(access_groups) == 0
access_groups = router.get_model_access_groups(
model_name="gpt-3.5-turbo", team_id="team_1"
)
assert list(access_groups.keys()) == ["default-models"]
@pytest.mark.asyncio
async def test_acompletion_streaming_iterator():
"""Test _acompletion_streaming_iterator for normal streaming and fallback behavior."""
from unittest.mock import AsyncMock, MagicMock
from litellm.exceptions import MidStreamFallbackError
from litellm.types.utils import ModelResponseStream
# Helper class for creating async iterators
    class AsyncIterator:
        def __init__(self, items):
            self.items = items
            self.index = 0

        def __aiter__(self):
            return self

        async def __anext__(self):
            if self.index >= len(self.items):
                raise StopAsyncIteration
            item = self.items[self.index]
            self.index += 1
            return item
# Set up router with fallback configuration
router = litellm.Router(
model_list=[
{
"model_name": "gpt-4",
"litellm_params": {"model": "gpt-4", "api_key": "fake-key-1"},
},
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {"model": "gpt-3.5-turbo", "api_key": "fake-key-2"},
},
],
fallbacks=[{"gpt-4": ["gpt-3.5-turbo"]}],
set_verbose=True,
)
# Test data
messages = [{"role": "user", "content": "Hello"}]
initial_kwargs = {"model": "gpt-4", "stream": True, "temperature": 0.7}
# Test 1: Successful streaming (no errors)
print("\n=== Test 1: Successful streaming ===")
# Mock successful streaming response
mock_chunks = [
MagicMock(choices=[MagicMock(delta=MagicMock(content="Hello"))]),
MagicMock(choices=[MagicMock(delta=MagicMock(content=" there"))]),
MagicMock(choices=[MagicMock(delta=MagicMock(content="!"))]),
]
mock_response = AsyncIterator(mock_chunks)
setattr(mock_response, "model", "gpt-4")
setattr(mock_response, "custom_llm_provider", "openai")
setattr(mock_response, "logging_obj", MagicMock())
result = await router._acompletion_streaming_iterator(
model_response=mock_response, messages=messages, initial_kwargs=initial_kwargs
)
# Collect streamed chunks
collected_chunks = []
async for chunk in result:
collected_chunks.append(chunk)
assert len(collected_chunks) == 3
assert all(chunk in mock_chunks for chunk in collected_chunks)
print("✓ Successfully streamed all chunks")
# Test 2: MidStreamFallbackError with fallback
print("\n=== Test 2: MidStreamFallbackError with fallback ===")
# Create error that should trigger after first chunk
error = MidStreamFallbackError(
message="Connection lost",
model="gpt-4",
llm_provider="openai",
generated_content="Hello",
)
class AsyncIteratorWithError:
def __init__(self, items, error_after_index):
self.items = items
self.index = 0
self.error_after_index = error_after_index
self.chunks = []
def __aiter__(self):
return self
async def __anext__(self):
if self.index >= len(self.items):
raise StopAsyncIteration
if self.index == self.error_after_index:
raise error
item = self.items[self.index]
self.index += 1
return item
mock_error_response = AsyncIteratorWithError(
mock_chunks, 1
) # Error after first chunk
setattr(mock_error_response, "model", "gpt-4")
setattr(mock_error_response, "custom_llm_provider", "openai")
setattr(mock_error_response, "logging_obj", MagicMock())
# Mock the fallback response
fallback_chunks = [
MagicMock(choices=[MagicMock(delta=MagicMock(content=" world"))]),
MagicMock(choices=[MagicMock(delta=MagicMock(content="!"))]),
]
mock_fallback_response = AsyncIterator(fallback_chunks)
# Mock the fallback function
with patch.object(
router,
"async_function_with_fallbacks_common_utils",
return_value=mock_fallback_response,
) as mock_fallback_utils:
collected_chunks = []
result = await router._acompletion_streaming_iterator(
model_response=mock_error_response,
messages=messages,
initial_kwargs=initial_kwargs,
)
async for chunk in result:
collected_chunks.append(chunk)
# Verify fallback was called
assert mock_fallback_utils.called
call_args = mock_fallback_utils.call_args
# Check that generated content was added to messages
fallback_kwargs = call_args.kwargs["kwargs"]
modified_messages = fallback_kwargs["messages"]
# Should have original message + system message + assistant message with prefix
assert len(modified_messages) == 3
assert modified_messages[0] == {"role": "user", "content": "Hello"}
assert modified_messages[1]["role"] == "system"
assert "continuation" in modified_messages[1]["content"]
assert modified_messages[2]["role"] == "assistant"
assert modified_messages[2]["content"] == "Hello"
assert modified_messages[2]["prefix"] == True
# Verify fallback parameters
assert call_args.kwargs["disable_fallbacks"] == False
assert call_args.kwargs["model_group"] == "gpt-4"
# Should get original chunk + fallback chunks
assert len(collected_chunks) == 3 # 1 original + 2 fallback
print("✓ Fallback system called correctly with proper message modification")
print("\n=== All tests passed! ===")
@pytest.mark.asyncio
async def test_acompletion_streaming_iterator_edge_cases():
"""Test edge cases for _acompletion_streaming_iterator."""
from unittest.mock import MagicMock
from litellm.exceptions import MidStreamFallbackError
router = litellm.Router(
model_list=[
{
"model_name": "gpt-4",
"litellm_params": {"model": "gpt-4", "api_key": "fake-key"},
}
],
set_verbose=True,
)
messages = [{"role": "user", "content": "Test"}]
initial_kwargs = {"model": "gpt-4", "stream": True}
# Test: Empty generated content
empty_error = MidStreamFallbackError(
message="Error",
model="gpt-4",
llm_provider="openai",
generated_content="", # Empty content
)
class AsyncIteratorImmediateError:
def __init__(self):
self.model = "gpt-4"
self.custom_llm_provider = "openai"
self.logging_obj = MagicMock()
self.chunks = []
def __aiter__(self):
return self
async def __anext__(self):
raise empty_error
mock_response = AsyncIteratorImmediateError()
# Mock empty fallback response using AsyncIterator
class EmptyAsyncIterator:
def __aiter__(self):
return self
async def __anext__(self):
raise StopAsyncIteration
mock_fallback_response = EmptyAsyncIterator()
with patch.object(
router,
"async_function_with_fallbacks_common_utils",
return_value=mock_fallback_response,
) as mock_fallback_utils:
collected_chunks = []
iterator = await router._acompletion_streaming_iterator(
model_response=mock_response,
messages=messages,
initial_kwargs=initial_kwargs,
)
async for chunk in iterator:
collected_chunks.append(chunk)
# Should still call fallback even with empty content
assert mock_fallback_utils.called
fallback_kwargs = mock_fallback_utils.call_args.kwargs["kwargs"]
modified_messages = fallback_kwargs["messages"]
# Should have assistant message with empty content
assert modified_messages[2]["content"] == ""
print("✓ Handles empty generated content correctly")
print("✓ Edge case tests passed!")
@pytest.mark.asyncio
async def test_async_function_with_fallbacks_common_utils():
"""Test the async_function_with_fallbacks_common_utils method"""
# Create a basic router for testing
router = litellm.Router(
model_list=[
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {
"model": "gpt-3.5-turbo",
},
}
],
max_fallbacks=5,
)
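    # async_function_with_fallbacks_common_utils should re-raise the original
    # exception whenever fallbacks cannot run (disabled, or no original model group).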
# Test case 1: disable_fallbacks=True should raise original exception
test_exception = Exception("Test error")
with pytest.raises(Exception, match="Test error"):
await router.async_function_with_fallbacks_common_utils(
e=test_exception,
disable_fallbacks=True,
fallbacks=None,
context_window_fallbacks=None,
content_policy_fallbacks=None,
model_group="gpt-3.5-turbo",
args=(),
kwargs=MagicMock(),
)
# Test case 2: original_model_group=None should raise original exception
with pytest.raises(Exception, match="Test error"):
await router.async_function_with_fallbacks_common_utils(
e=test_exception,
disable_fallbacks=False,
fallbacks=None,
context_window_fallbacks=None,
content_policy_fallbacks=None,
model_group="gpt-3.5-turbo",
args=(),
kwargs={}, # No model key
)
def test_should_include_deployment():
"""Test that Router.should_include_deployment returns the correct response"""
router = litellm.Router(
model_list=[
{
"model_name": "model_name_a28a12f9-3e44-4861-bd4f-325f2d309ce8_cd5dc6fb-b046-4e05-ae1d-32ba4d936266",
"litellm_params": {"model": "openai/*"},
"model_info": {
"team_id": "a28a12f9-3e44-4861-bd4f-325f2d309ce8",
"team_public_model_name": "openai/*",
},
}
],
)
model = {
"model_name": "model_name_a28a12f9-3e44-4861-bd4f-325f2d309ce8_cd5dc6fb-b046-4e05-ae1d-32ba4d936266",
"litellm_params": {
"api_key": "sk-proj-1234567890",
"custom_llm_provider": "openai",
"use_in_pass_through": False,
"use_litellm_proxy": False,
"merge_reasoning_content_in_choices": False,
"model": "openai/*",
},
"model_info": {
"id": "95f58039-d54a-4d1c-b700-5e32e99a1120",
"db_model": True,
"updated_by": "64a2f787-0863-4d76-9516-2dc49c1598e8",
"created_by": "64a2f787-0863-4d76-9516-2dc49c1598e8",
"team_id": "a28a12f9-3e44-4861-bd4f-325f2d309ce8",
"team_public_model_name": "openai/*",
"mode": "completion",
"access_groups": ["restricted-models-openai"],
},
}
model_name = "openai/o4-mini-deep-research"
team_id = "a28a12f9-3e44-4861-bd4f-325f2d309ce8"
assert router.get_model_list(
model_name=model_name,
team_id=team_id,
)
def test_get_deployment_model_info_base_model_flow():
"""Test that get_deployment_model_info correctly handles the base model flow"""
from unittest.mock import patch
router = litellm.Router(
model_list=[
{
"model_name": "test-model",
"litellm_params": {"model": "gpt-3.5-turbo"},
}
],
)
# Mock data for the test
mock_custom_model_info = {
"base_model": "gpt-3.5-turbo",
"input_cost_per_token": 0.001,
"output_cost_per_token": 0.002,
"custom_field": "custom_value",
}
mock_base_model_info = {
"key": "gpt-3.5-turbo",
"max_tokens": 4096,
"max_input_tokens": 4096,
"max_output_tokens": 4096,
"input_cost_per_token": 0.0015, # This should be overridden by custom model info
"output_cost_per_token": 0.002,
"litellm_provider": "openai",
"mode": "chat",
"supported_openai_params": ["temperature", "max_tokens"],
}
mock_litellm_model_name_info = {
"key": "test-model",
"max_tokens": 2048,
"max_input_tokens": 2048,
"max_output_tokens": 2048,
"input_cost_per_token": 0.0005,
"output_cost_per_token": 0.001,
"litellm_provider": "test_provider",
"mode": "completion",
"supported_openai_params": ["temperature"],
}
# Test Case 1: Base model flow with custom model info that has base_model
with patch.object(
litellm, "model_cost", {"test-custom-model": mock_custom_model_info}
):
with patch.object(litellm, "get_model_info") as mock_get_model_info:
# Configure mock returns
mock_get_model_info.side_effect = lambda model: {
"gpt-3.5-turbo": mock_base_model_info,
"test-model": mock_litellm_model_name_info,
}.get(model)
result = router.get_deployment_model_info(
model_id="test-custom-model", model_name="test-model"
)
# Verify that get_model_info was called for both base model and model name
assert mock_get_model_info.call_count == 2
mock_get_model_info.assert_any_call(
model="gpt-3.5-turbo"
) # base model call
mock_get_model_info.assert_any_call(model="test-model") # model name call
# Verify the result contains merged information
assert result is not None
# Test the correct merging behavior after fix:
# 1. base_model_info provides defaults, custom_model_info overrides (correct priority)
# 2. The result of step 1 gets merged into litellm_model_name_info (custom+base override litellm)
# Fields from custom model (should override base model values)
assert (
result["input_cost_per_token"] == 0.001
) # From custom model (overrides base 0.0015)
assert (
result["output_cost_per_token"] == 0.002
) # From custom model (same as base)
assert result["custom_field"] == "custom_value" # From custom model
# Fields from base model that weren't overridden by custom
assert result["max_tokens"] == 4096 # From base model
assert result["litellm_provider"] == "openai" # From base model
assert (
result["mode"] == "chat"
) # From base model (overrides litellm "completion")
# The key field comes from base model since both base and litellm have it
# and base model info overrides litellm model name info in final merge
assert (
result["key"] == "gpt-3.5-turbo"
) # From base model (overrides litellm key)
# Test Case 2: Custom model info without base_model
mock_custom_model_info_no_base = {
"input_cost_per_token": 0.001,
"output_cost_per_token": 0.002,
"custom_field": "custom_value",
}
with patch.object(
litellm,
"model_cost",
{"test-custom-model-no-base": mock_custom_model_info_no_base},
):
with patch.object(litellm, "get_model_info") as mock_get_model_info:
mock_get_model_info.side_effect = lambda model: {
"test-model": mock_litellm_model_name_info,
}.get(model)
result = router.get_deployment_model_info(
model_id="test-custom-model-no-base", model_name="test-model"
)
# Should only call get_model_info once for model name (no base model)
assert mock_get_model_info.call_count == 1
mock_get_model_info.assert_called_with(model="test-model")
# Verify the result contains merged information
assert result is not None
assert result["input_cost_per_token"] == 0.001 # From custom model
assert result["max_tokens"] == 2048 # From litellm model name info
assert result["custom_field"] == "custom_value" # From custom model
assert result["mode"] == "completion" # From litellm model name info
# Test Case 3: No custom model info, only litellm model name info
with patch.object(litellm, "model_cost", {}): # Empty model cost
with patch.object(litellm, "get_model_info") as mock_get_model_info:
mock_get_model_info.side_effect = lambda model: {
"test-model": mock_litellm_model_name_info,
}.get(model)
result = router.get_deployment_model_info(
model_id="non-existent-model", model_name="test-model"
)
# Should only call get_model_info once for model name
assert mock_get_model_info.call_count == 1
mock_get_model_info.assert_called_with(model="test-model")
# Result should be just the litellm model name info
assert result is not None
assert result == mock_litellm_model_name_info
# Test Case 4: Base model info retrieval fails (exception handling)
mock_custom_model_info_invalid_base = {
"base_model": "invalid-base-model",
"input_cost_per_token": 0.001,
"output_cost_per_token": 0.002,
}
with patch.object(
litellm,
"model_cost",
{"test-custom-model-invalid": mock_custom_model_info_invalid_base},
):
with patch.object(litellm, "get_model_info") as mock_get_model_info:
# Mock get_model_info to raise exception for invalid base model
def mock_get_model_info_side_effect(model):
if model == "invalid-base-model":
raise Exception("Model not found")
elif model == "test-model":
return mock_litellm_model_name_info
return None
mock_get_model_info.side_effect = mock_get_model_info_side_effect
result = router.get_deployment_model_info(
model_id="test-custom-model-invalid", model_name="test-model"
)
# Should handle exception gracefully and still return merged result
assert result is not None
assert result["input_cost_per_token"] == 0.001 # From custom model
assert result["mode"] == "completion" # From litellm model name info
# Test Case 5: Both model_cost.get() and get_model_info() return None
with patch.object(litellm, "model_cost", {}):
with patch.object(
litellm, "get_model_info", side_effect=Exception("Not found")
):
result = router.get_deployment_model_info(
model_id="non-existent", model_name="non-existent"
)
# Should return None when no model info is found
assert result is None
print("✓ All base model flow test cases passed!")
@patch("litellm.model_cost", {})
def test_get_deployment_model_info_base_model_merge_priority():
"""Test that base model info merging respects the correct priority order"""
from unittest.mock import patch
router = litellm.Router(
model_list=[
{
"model_name": "test-model",
"litellm_params": {"model": "gpt-3.5-turbo"},
}
],
)
# Test data with overlapping fields to test merge priority
mock_custom_model_info = {
"base_model": "gpt-4",
"input_cost_per_token": 0.01, # Should override base model value
"max_tokens": 8000, # Should override base model value
"custom_only_field": "custom_value",
}
mock_base_model_info = {
"key": "gpt-4",
"max_tokens": 4096, # Should be overridden by custom model
"input_cost_per_token": 0.03, # Should be overridden by custom model
"output_cost_per_token": 0.06, # Should be preserved (not in custom)
"litellm_provider": "openai",
"base_only_field": "base_value",
}
mock_litellm_model_name_info = {
"key": "test-model",
"max_tokens": 2048, # Should be overridden by final custom model info
"input_cost_per_token": 0.005, # Should be overridden by final custom model info
"output_cost_per_token": 0.01, # Should be overridden by final custom model info
"mode": "completion",
"litellm_only_field": "litellm_value",
}
with patch.object(
litellm, "model_cost", {"custom-model-id": mock_custom_model_info}
):
with patch.object(litellm, "get_model_info") as mock_get_model_info:
mock_get_model_info.side_effect = lambda model: {
"gpt-4": mock_base_model_info,
"test-model": mock_litellm_model_name_info,
}.get(model)
result = router.get_deployment_model_info(
model_id="custom-model-id", model_name="test-model"
)
assert result is not None
# Test correct merge priority after fix:
# 1. base_model_info provides defaults
# 2. custom_model_info overrides base_model_info
# 3. Result from steps 1-2 overrides litellm_model_name_info
# Fields that should come from custom model info (highest priority)
assert (
result["input_cost_per_token"] == 0.01
) # From custom model (overrides base 0.03)
assert (
result["max_tokens"] == 8000
) # From custom model (overrides base 4096)
assert result["custom_only_field"] == "custom_value" # From custom model
# Fields that should come from base model (not overridden by custom)
assert (
result["output_cost_per_token"] == 0.06
) # From base model (not in custom)
assert (
result["litellm_provider"] == "openai"
) # From base model (not in custom)
assert (
result["base_only_field"] == "base_value"
) # From base model (not in custom)
# Fields that should come from litellm model name info (not overridden by custom+base)
assert (
result["mode"] == "completion"
) # From litellm model name info (not in custom or base)
assert (
result["litellm_only_field"] == "litellm_value"
) # From litellm model name info (not in custom or base)
# Key comes from base model since both base and litellm have key fields
# and the merged custom+base overrides litellm in the final merge
assert result["key"] == "gpt-4"
print("✓ Base model merge priority test passed!")