Added LiteLLM to the stack

2025-08-18 09:40:50 +00:00
parent 0648c1968c
commit d220b04e32
2682 changed files with 533609 additions and 1 deletions

View File

@@ -0,0 +1,202 @@
"""
Test adding a pass through assemblyai model + api key + api base to the db
wait 20 seconds
make request
Cases to cover
1. user points api base to <proxy-base>/assemblyai
2. user points api base to <proxy-base>/asssemblyai/us
3. user points api base to <proxy-base>/assemblyai/eu
4. Bad API Key / credential - 401
"""
import time
import assemblyai as aai
import pytest
import httpx
import os
import json
TEST_MASTER_KEY = "sk-1234"
PROXY_BASE_URL = "http://0.0.0.0:4000"
US_BASE_URL = f"{PROXY_BASE_URL}/assemblyai"
EU_BASE_URL = f"{PROXY_BASE_URL}/eu.assemblyai"
ASSEMBLYAI_API_KEY_ENV_VAR = "TEST_SPECIAL_ASSEMBLYAI_API_KEY"
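def _example_point_sdk_at_proxy(virtual_key: str, use_eu: bool = False) -> None:
    """
    Illustration only (not called by the tests): a sketch of how an end user would
    point the official AssemblyAI SDK at the proxy passthrough routes exercised in
    this file. The LiteLLM virtual key is sent as the Bearer credential instead of
    the raw AssemblyAI key.
    """
    aai.settings.api_key = f"Bearer {virtual_key}"
    aai.settings.base_url = EU_BASE_URL if use_eu else US_BASE_URL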
def _delete_all_assemblyai_models_from_db():
"""
Delete all assemblyai models from the db
"""
print("Deleting all assemblyai models from the db.......")
model_list_response = httpx.get(
url=f"{PROXY_BASE_URL}/v2/model/info",
headers={"Authorization": f"Bearer {TEST_MASTER_KEY}"},
)
response_data = model_list_response.json()
print("model list response", json.dumps(response_data, indent=4, default=str))
# Filter for only AssemblyAI models
assemblyai_models = [
model
for model in response_data["data"]
if model.get("litellm_params", {}).get("custom_llm_provider") == "assemblyai"
]
for model in assemblyai_models:
model_id = model["model_info"]["id"]
httpx.post(
url=f"{PROXY_BASE_URL}/model/delete",
headers={"Authorization": f"Bearer {TEST_MASTER_KEY}"},
json={"id": model_id},
)
print("Deleted all assemblyai models from the db")
@pytest.fixture(autouse=True)
def cleanup_assemblyai_models():
"""
Fixture to clean up AssemblyAI models before and after each test
"""
# Clean up before test
_delete_all_assemblyai_models_from_db()
# Run the test
yield
# Clean up after test
_delete_all_assemblyai_models_from_db()
def test_e2e_assemblyai_passthrough():
"""
Test adding a pass through assemblyai model + api key + api base to the db
wait 20 seconds
make request
"""
add_assembly_ai_model_to_db(api_base="https://api.assemblyai.com")
virtual_key = create_virtual_key()
# make request
make_assemblyai_basic_transcribe_request(
virtual_key=virtual_key, assemblyai_base_url=US_BASE_URL
)
pass
def test_e2e_assemblyai_passthrough_eu():
"""
Test adding a pass through assemblyai model + api key + api base to the db
wait 20 seconds
make request
"""
add_assembly_ai_model_to_db(api_base="https://api.eu.assemblyai.com")
virtual_key = create_virtual_key()
# make request
make_assemblyai_basic_transcribe_request(
virtual_key=virtual_key, assemblyai_base_url=EU_BASE_URL
)
pass
def test_assemblyai_routes_with_bad_api_key():
"""
Test AssemblyAI endpoints with invalid API key to ensure proper error handling
"""
bad_api_key = "sk-12222"
payload = {
"audio_url": "https://assembly.ai/wildfires.mp3",
"audio_end_at": 280,
"audio_start_from": 10,
"auto_chapters": True,
}
headers = {
"Authorization": f"Bearer {bad_api_key}",
"Content-Type": "application/json",
}
# Test EU endpoint
eu_response = httpx.post(
f"{PROXY_BASE_URL}/eu.assemblyai/v2/transcript", headers=headers, json=payload
)
assert (
eu_response.status_code == 401
), f"Expected 401 unauthorized, got {eu_response.status_code}"
# Test US endpoint
us_response = httpx.post(
f"{PROXY_BASE_URL}/assemblyai/v2/transcript", headers=headers, json=payload
)
assert (
us_response.status_code == 401
), f"Expected 401 unauthorized, got {us_response.status_code}"
def create_virtual_key():
"""
Create a virtual key
"""
response = httpx.post(
url=f"{PROXY_BASE_URL}/key/generate",
headers={"Authorization": f"Bearer {TEST_MASTER_KEY}"},
json={},
)
print(response.json())
return response.json()["key"]
def add_assembly_ai_model_to_db(
api_base: str,
):
"""
    Add the assemblyai model to the db - makes an HTTP request to the /model/new endpoint on PROXY_BASE_URL
"""
print("assmbly ai api key", os.getenv(ASSEMBLYAI_API_KEY_ENV_VAR))
response = httpx.post(
url=f"{PROXY_BASE_URL}/model/new",
headers={"Authorization": f"Bearer {TEST_MASTER_KEY}"},
json={
"model_name": "assemblyai/*",
"litellm_params": {
"model": "assemblyai/*",
"custom_llm_provider": "assemblyai",
"api_key": os.getenv(ASSEMBLYAI_API_KEY_ENV_VAR),
"api_base": api_base,
"use_in_pass_through": True,
},
"model_info": {},
},
)
print(response.json())
pass
def make_assemblyai_basic_transcribe_request(
virtual_key: str, assemblyai_base_url: str
):
print("making basic transcribe request to assemblyai passthrough")
    # Point the SDK at the proxy, using the LiteLLM virtual key as the Bearer credential
    aai.settings.api_key = f"Bearer {virtual_key}"
    aai.settings.base_url = assemblyai_base_url
# URL of the file to transcribe
FILE_URL = "https://assembly.ai/wildfires.mp3"
# You can also transcribe a local file by passing in a file path
# FILE_URL = './path/to/file.mp3'
transcriber = aai.Transcriber()
transcript = transcriber.transcribe(FILE_URL)
print(transcript)
print(transcript.id)
if transcript.id:
transcript.delete_by_id(transcript.id)
else:
pytest.fail("Failed to get transcript id")
if transcript.status == aai.TranscriptStatus.error:
print(transcript.error)
pytest.fail(f"Failed to transcribe file error: {transcript.error}")
else:
print(transcript.text)

View File

@@ -0,0 +1,93 @@
"""
PROD TEST - DO NOT Delete this Test
e2e test for langfuse callback in DB
- Add langfuse callback to DB - with /config/update
- wait 20 seconds for the callback to be loaded into the instance
- Make a /chat/completions request to the proxy
- Check if the request is logged in Langfuse
"""
import asyncio
import os
import aiohttp
import pytest
from dotenv import load_dotenv
from openai import AsyncOpenAI
from openai.types.chat import ChatCompletion
load_dotenv()
# used for testing
LANGFUSE_BASE_URL = "https://exampleopenaiendpoint-production-c715.up.railway.app"
async def config_update(session, routing_strategy=None):
url = "http://0.0.0.0:4000/config/update"
headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
print("routing_strategy: ", routing_strategy)
data = {
"litellm_settings": {"success_callback": ["langfuse"]},
"environment_variables": {
"LANGFUSE_PUBLIC_KEY": "any-public-key",
"LANGFUSE_SECRET_KEY": "any-secret-key",
"LANGFUSE_HOST": LANGFUSE_BASE_URL,
},
}
async with session.post(url, headers=headers, json=data) as response:
status = response.status
response_text = await response.text()
print(response_text)
print("status: ", status)
if status != 200:
raise Exception(f"Request did not return a 200 status code: {status}")
return await response.json()
async def check_langfuse_request(response_id: str):
async with aiohttp.ClientSession() as session:
url = f"{LANGFUSE_BASE_URL}/langfuse/trace/{response_id}"
async with session.get(url) as response:
response_json = await response.json()
assert response.status == 200, f"Expected status 200, got {response.status}"
            assert (
                response_json["exists"] is True
            ), f"Request {response_id} not found in Langfuse traces"
            assert response_json["request_id"] == response_id, "Request ID mismatch"
async def make_chat_completions_request() -> ChatCompletion:
client = AsyncOpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")
response = await client.chat.completions.create(
model="fake-openai-endpoint",
messages=[{"role": "user", "content": "Hello, world!"}],
)
print(response)
return response
@pytest.mark.asyncio
async def test_e2e_langfuse_callbacks_in_db():
    # add langfuse callback to DB (close the session once the config update is done)
    async with aiohttp.ClientSession() as session:
        await config_update(session)
    # wait 20 seconds for the callback to be loaded into the instance
    await asyncio.sleep(20)
    # make a /chat/completions request to the proxy
    response = await make_chat_completions_request()
    print(response)
    response_id = response.id
    print("response_id: ", response_id)
    await asyncio.sleep(11)
    # check if the request is logged in Langfuse
    await check_langfuse_request(response_id)

View File

@@ -0,0 +1,325 @@
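"""
Unit tests for the LiteLLM proxy MCP server management endpoints:
creating MCP servers, rejecting duplicate server_ids, enforcing admin-only access,
and validating server names (aliases may not contain '-').
"""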
from datetime import datetime
from typing import List, Optional
import pytest
import uuid
import os
import asyncio
from unittest import mock
from fastapi.testclient import TestClient
from fastapi import FastAPI
from starlette import status
from litellm.constants import LITELLM_PROXY_ADMIN_NAME
from litellm.proxy._types import MCPSpecVersion, MCPSpecVersionType, MCPTransportType, MCPTransport, NewMCPServerRequest, LiteLLM_MCPServerTable, LitellmUserRoles, UserAPIKeyAuth
from litellm.types.mcp import MCPAuth
from litellm.proxy.management_endpoints.mcp_management_endpoints import does_mcp_server_exist
TEST_MASTER_KEY = os.getenv("LITELLM_MASTER_KEY", "sk-1234")
def generate_mcpserver_record(url: Optional[str] = None,
transport: Optional[MCPTransportType] = None,
spec_version: Optional[MCPSpecVersionType] = None) -> LiteLLM_MCPServerTable:
"""
Generate a mock record for testing.
"""
now = datetime.now()
return LiteLLM_MCPServerTable(
server_id=str(uuid.uuid4()),
alias="Test Server",
url=url or "http://localhost.com:8080/mcp",
transport=transport or MCPTransport.sse,
spec_version=spec_version or MCPSpecVersion.mar_2025,
created_at=now,
updated_at=now,
)
# UUID validity check (cheers, Stack Overflow)
def is_valid_uuid(val):
try:
uuid.UUID(str(val))
return True
except ValueError:
return False
def generate_mcpserver_create_request(
server_id: Optional[str] = None,
url: Optional[str] = None,
transport: Optional[MCPTransportType] = None,
spec_version: Optional[MCPSpecVersionType] = None) -> NewMCPServerRequest:
"""
Generate a mock create request for testing.
"""
return NewMCPServerRequest(
server_id=server_id,
alias="Test Server",
url=url or "http://localhost.com:8080/mcp",
transport=transport or MCPTransport.sse,
spec_version=spec_version or MCPSpecVersion.mar_2025,
)
def assert_mcp_server_record_same(mcp_server: NewMCPServerRequest, resp: LiteLLM_MCPServerTable):
"""
Assert that the mcp server record is created correctly.
"""
if mcp_server.server_id is not None:
assert resp.server_id == mcp_server.server_id
else:
assert is_valid_uuid(resp.server_id)
assert resp.alias == mcp_server.alias
assert resp.url == mcp_server.url
assert resp.description == mcp_server.description
assert resp.transport == mcp_server.transport
assert resp.spec_version == mcp_server.spec_version
assert resp.auth_type == mcp_server.auth_type
assert resp.created_at is not None
assert resp.updated_at is not None
assert resp.created_by == LITELLM_PROXY_ADMIN_NAME
assert resp.updated_by == LITELLM_PROXY_ADMIN_NAME
def test_does_mcp_server_exist():
"""
Unit Test if the MCP server exists in the list.
"""
mcp_server_records: List[LiteLLM_MCPServerTable] = [generate_mcpserver_record(), generate_mcpserver_record()]
# test all records are found
for record in mcp_server_records:
assert does_mcp_server_exist(mcp_server_records, record.server_id)
# test record not found
not_found_record = str(uuid.uuid4())
    assert not does_mcp_server_exist(mcp_server_records, not_found_record)
@pytest.mark.asyncio
async def test_create_mcp_server_direct():
"""
Direct test of the MCP server creation logic without HTTP calls.
"""
# Mock the database functions directly
with mock.patch("litellm.proxy.management_endpoints.mcp_management_endpoints.MCP_AVAILABLE", True), \
mock.patch("litellm.proxy.management_endpoints.mcp_management_endpoints.get_prisma_client_or_throw") as mock_get_prisma, \
mock.patch("litellm.proxy.management_endpoints.mcp_management_endpoints.create_mcp_server", new_callable=mock.AsyncMock) as mock_create, \
mock.patch("litellm.proxy.management_endpoints.mcp_management_endpoints.get_mcp_server", new_callable=mock.AsyncMock) as mock_get_server, \
mock.patch("litellm.proxy.management_endpoints.mcp_management_endpoints.global_mcp_server_manager") as mock_manager:
# Import after mocking
from litellm.proxy.management_endpoints.mcp_management_endpoints import add_mcp_server
# Mock database client
mock_prisma = mock.Mock()
mock_get_prisma.return_value = mock_prisma
# Mock server manager
mock_manager.add_update_server = mock.Mock()
mock_manager.reload_servers_from_database = mock.AsyncMock()
# Set up test data
server_id = str(uuid.uuid4())
mcp_server_request = generate_mcpserver_create_request(server_id=server_id)
# The function will normalize the alias by replacing spaces with underscores
expected_alias = mcp_server_request.alias.replace(' ', '_') if mcp_server_request.alias else None
expected_response = LiteLLM_MCPServerTable(
server_id=server_id,
alias=expected_alias, # Use the normalized alias
description=mcp_server_request.description,
url=mcp_server_request.url,
transport=mcp_server_request.transport,
spec_version=mcp_server_request.spec_version,
auth_type=mcp_server_request.auth_type,
created_at=datetime.now(),
updated_at=datetime.now(),
created_by=LITELLM_PROXY_ADMIN_NAME,
updated_by=LITELLM_PROXY_ADMIN_NAME,
teams=[]
)
# Mock the database calls
mock_get_server.return_value = None # Server doesn't exist yet
# Set up async mock for create_mcp_server using AsyncMock
mock_create.return_value = expected_response
# Create mock user auth
user_auth = UserAPIKeyAuth(
api_key=TEST_MASTER_KEY,
user_id="test-user",
user_role=LitellmUserRoles.PROXY_ADMIN
)
# Call the function directly
result = await add_mcp_server(
payload=mcp_server_request,
user_api_key_dict=user_auth
)
# Verify the result
assert result.server_id == server_id
assert result.alias == expected_alias # Check against normalized alias
assert result.url == mcp_server_request.url
assert result.transport == mcp_server_request.transport
assert result.spec_version == mcp_server_request.spec_version
# Verify mocks were called
mock_get_server.assert_called_once_with(mock_prisma, server_id)
mock_create.assert_called_once()
mock_manager.add_update_server.assert_called_once_with(expected_response)
@pytest.mark.asyncio
async def test_create_duplicate_mcp_server():
"""
Test that creating a duplicate MCP server fails appropriately.
"""
# Mock the database functions directly
with mock.patch("litellm.proxy.management_endpoints.mcp_management_endpoints.MCP_AVAILABLE", True), \
mock.patch("litellm.proxy.management_endpoints.mcp_management_endpoints.get_prisma_client_or_throw") as mock_get_prisma, \
mock.patch("litellm.proxy.management_endpoints.mcp_management_endpoints.get_mcp_server", new_callable=mock.AsyncMock) as mock_get_server:
# Import after mocking
from litellm.proxy.management_endpoints.mcp_management_endpoints import add_mcp_server
from fastapi import HTTPException
# Mock database client
mock_prisma = mock.Mock()
mock_get_prisma.return_value = mock_prisma
# Set up test data
server_id = str(uuid.uuid4())
mcp_server_request = generate_mcpserver_create_request(server_id=server_id)
existing_server = LiteLLM_MCPServerTable(
server_id=server_id,
alias="Existing Server",
url="http://existing.com",
transport=MCPTransport.sse,
spec_version=MCPSpecVersion.mar_2025,
created_at=datetime.now(),
updated_at=datetime.now(),
teams=[]
)
# Mock that server already exists
mock_get_server.return_value = existing_server
# Create mock user auth
user_auth = UserAPIKeyAuth(
api_key=TEST_MASTER_KEY,
user_id="test-user",
user_role=LitellmUserRoles.PROXY_ADMIN
)
# Expect HTTPException to be raised
with pytest.raises(HTTPException) as exc_info:
await add_mcp_server(
payload=mcp_server_request,
user_api_key_dict=user_auth
)
# Verify the exception details
assert exc_info.value.status_code == 400
assert "already exists" in str(exc_info.value.detail)
@pytest.mark.asyncio
async def test_create_mcp_server_auth_failure():
"""
Test that non-admin users cannot create MCP servers.
"""
# Mock the database functions directly
with mock.patch("litellm.proxy.management_endpoints.mcp_management_endpoints.MCP_AVAILABLE", True), \
mock.patch("litellm.proxy.management_endpoints.mcp_management_endpoints.get_prisma_client_or_throw") as mock_get_prisma:
# Import after mocking
from litellm.proxy.management_endpoints.mcp_management_endpoints import add_mcp_server
from fastapi import HTTPException
# Mock database client
mock_prisma = mock.Mock()
mock_get_prisma.return_value = mock_prisma
# Set up test data
server_id = str(uuid.uuid4())
mcp_server_request = generate_mcpserver_create_request(server_id=server_id)
# Create mock user auth without admin role
user_auth = UserAPIKeyAuth(
api_key=TEST_MASTER_KEY,
user_id="test-user",
user_role=LitellmUserRoles.INTERNAL_USER # Not an admin
)
# Expect HTTPException to be raised
with pytest.raises(HTTPException) as exc_info:
await add_mcp_server(
payload=mcp_server_request,
user_api_key_dict=user_auth
)
# Verify the exception details
assert exc_info.value.status_code == 403
assert "permission" in str(exc_info.value.detail)
@pytest.mark.asyncio
async def test_create_mcp_server_invalid_alias():
"""
Test that creating an MCP server with a '-' in the alias fails with the correct error.
"""
with mock.patch("litellm.proxy.management_endpoints.mcp_management_endpoints.MCP_AVAILABLE", True), \
mock.patch("litellm.proxy.management_endpoints.mcp_management_endpoints.get_prisma_client_or_throw") as mock_get_prisma, \
mock.patch("litellm.proxy.management_endpoints.mcp_management_endpoints.get_mcp_server") as mock_get_server, \
mock.patch("litellm.proxy.management_endpoints.mcp_management_endpoints.create_mcp_server") as mock_create:
from litellm.proxy.management_endpoints.mcp_management_endpoints import add_mcp_server
from fastapi import HTTPException
mock_prisma = mock.Mock()
mock_get_prisma.return_value = mock_prisma
# Set up test data with invalid alias
server_id = str(uuid.uuid4())
mcp_server_request = generate_mcpserver_create_request(server_id=server_id)
mcp_server_request.alias = "invalid-alias" # This should trigger the validation error
# Mock that server does not exist
mock_get_server.return_value = None
# Mock create_mcp_server to prevent 500 error (this should not be called due to validation)
mock_create.return_value = None
user_auth = UserAPIKeyAuth(
api_key=TEST_MASTER_KEY,
user_id="test-user",
user_role=LitellmUserRoles.PROXY_ADMIN
)
with pytest.raises(HTTPException) as exc_info:
await add_mcp_server(
payload=mcp_server_request,
user_api_key_dict=user_auth
)
assert exc_info.value.status_code == 400
assert "Server name cannot contain '-'. Use an alternative character instead Found: invalid-alias" in str(exc_info.value.detail)
def test_validate_mcp_server_name_direct():
"""
Test the validation function directly to ensure it works.
"""
from litellm.proxy._experimental.mcp_server.utils import validate_mcp_server_name
from fastapi import HTTPException
# Test that valid names pass
validate_mcp_server_name("valid_name")
validate_mcp_server_name("valid name")
# Test that invalid names with hyphens raise exceptions
with pytest.raises(Exception) as exc_info:
validate_mcp_server_name("invalid-name")
assert "cannot contain" in str(exc_info.value)
# Test that invalid names with hyphens raise HTTPException when requested
with pytest.raises(HTTPException) as exc_info:
validate_mcp_server_name("invalid-name", raise_http_exception=True)
assert exc_info.value.status_code == 400
assert "cannot contain" in str(exc_info.value.detail)

View File

@@ -0,0 +1,208 @@
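"""
E2E tests against a running LiteLLM proxy for requests that reference a
non-existent model: OpenAI SDK error handling (chat, completions, embeddings,
images), missing `model` parameter via curl, and error/spend log creation for
failed requests.
"""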
import pytest
from openai import OpenAI, BadRequestError, AsyncOpenAI
import asyncio
import httpx
def generate_key_sync():
url = "http://0.0.0.0:4000/key/generate"
headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
with httpx.Client() as client:
response = client.post(
url,
headers=headers,
json={
"models": [
"gpt-4",
"text-embedding-ada-002",
"dall-e-2",
"fake-openai-endpoint-2",
"mistral-embed",
"non-existent-model",
],
},
)
response_text = response.text
print(response_text)
print()
if response.status_code != 200:
raise Exception(
f"Request did not return a 200 status code: {response.status_code}"
)
response_data = response.json()
return response_data["key"]
def test_chat_completion_bad_model():
key = generate_key_sync()
client = OpenAI(api_key=key, base_url="http://0.0.0.0:4000")
with pytest.raises(BadRequestError) as excinfo:
client.chat.completions.create(
model="non-existent-model", messages=[{"role": "user", "content": "Hello!"}]
)
print(f"Chat completion error: {excinfo.value}")
def test_completion_bad_model():
key = generate_key_sync()
client = OpenAI(api_key=key, base_url="http://0.0.0.0:4000")
with pytest.raises(BadRequestError) as excinfo:
client.completions.create(model="non-existent-model", prompt="Hello!")
print(f"Completion error: {excinfo.value}")
def test_embeddings_bad_model():
key = generate_key_sync()
client = OpenAI(api_key=key, base_url="http://0.0.0.0:4000")
with pytest.raises(BadRequestError) as excinfo:
client.embeddings.create(model="non-existent-model", input="Hello world")
print(f"Embeddings error: {excinfo.value}")
def test_images_bad_model():
key = generate_key_sync()
client = OpenAI(api_key=key, base_url="http://0.0.0.0:4000")
with pytest.raises(BadRequestError) as excinfo:
client.images.generate(
model="non-existent-model", prompt="A cute baby sea otter"
)
print(f"Images error: {excinfo.value}")
@pytest.mark.asyncio
async def test_async_chat_completion_bad_model():
key = generate_key_sync()
async_client = AsyncOpenAI(api_key=key, base_url="http://0.0.0.0:4000")
with pytest.raises(BadRequestError) as excinfo:
await async_client.chat.completions.create(
model="non-existent-model", messages=[{"role": "user", "content": "Hello!"}]
)
print(f"Async chat completion error: {excinfo.value}")
@pytest.mark.parametrize(
"curl_command",
[
'curl http://0.0.0.0:4000/v1/chat/completions -H \'Content-Type: application/json\' -H \'Authorization: Bearer sk-1234\' -d \'{"messages":[{"role":"user","content":"Hello!"}]}\'',
"curl http://0.0.0.0:4000/v1/completions -H 'Content-Type: application/json' -H 'Authorization: Bearer sk-1234' -d '{\"prompt\":\"Hello!\"}'",
"curl http://0.0.0.0:4000/v1/embeddings -H 'Content-Type: application/json' -H 'Authorization: Bearer sk-1234' -d '{\"input\":\"Hello world\"}'",
"curl http://0.0.0.0:4000/v1/images/generations -H 'Content-Type: application/json' -H 'Authorization: Bearer sk-1234' -d '{\"prompt\":\"A cute baby sea otter\"}'",
],
ids=["chat", "completions", "embeddings", "images"],
)
def test_missing_model_parameter_curl(curl_command):
import subprocess
import json
# Run the curl command and capture the output
key = generate_key_sync()
curl_command = curl_command.replace("sk-1234", key)
result = subprocess.run(curl_command, shell=True, capture_output=True, text=True)
# Parse the JSON response
response = json.loads(result.stdout)
# Check that we got an error response
assert "error" in response
print("error in response", json.dumps(response, indent=4))
assert "litellm.BadRequestError" in response["error"]["message"]
@pytest.mark.asyncio
async def test_chat_completion_bad_model_with_spend_logs():
"""
Tests that Error Logs are created for failed requests
"""
import json
key = generate_key_sync()
# Use httpx to make the request and capture headers
url = "http://0.0.0.0:4000/v1/chat/completions"
headers = {"Authorization": f"Bearer {key}", "Content-Type": "application/json"}
payload = {
"model": "non-existent-model",
"messages": [{"role": "user", "content": "Hello!"}],
}
with httpx.Client() as client:
response = client.post(url, headers=headers, json=payload)
# Extract the litellm call ID from headers
litellm_call_id = response.headers.get("x-litellm-call-id")
print(f"Status code: {response.status_code}")
print(f"Headers: {dict(response.headers)}")
print(f"LiteLLM Call ID: {litellm_call_id}")
# Parse the JSON response body
try:
response_body = response.json()
print(f"Error response: {json.dumps(response_body, indent=4)}")
except json.JSONDecodeError:
print(f"Could not parse response body as JSON: {response.text}")
assert (
litellm_call_id is not None
), "Failed to get LiteLLM Call ID from response headers"
print("waiting for flushing error log to db....")
await asyncio.sleep(15)
# Now query the spend logs
url = "http://0.0.0.0:4000/spend/logs?request_id=" + litellm_call_id
headers = {"Authorization": f"Bearer sk-1234", "Content-Type": "application/json"}
with httpx.Client() as client:
response = client.get(
url,
headers=headers,
)
assert (
response.status_code == 200
), f"Failed to get spend logs: {response.status_code}"
spend_logs = response.json()
# Print the spend logs payload
print(f"Spend logs response: {json.dumps(spend_logs, indent=4)}")
# Verify we have logs for the failed request
assert len(spend_logs) > 0, "No spend logs found"
# Check if the error is recorded in the logs
log_entry = spend_logs[0] # Should be the specific log for our litellm_call_id
# Verify the structure of the log entry
assert log_entry["request_id"] == litellm_call_id
assert log_entry["model"] == "non-existent-model"
assert log_entry["model_group"] == "non-existent-model"
assert log_entry["spend"] == 0.0
assert log_entry["total_tokens"] == 0
assert log_entry["prompt_tokens"] == 0
assert log_entry["completion_tokens"] == 0
# Verify metadata fields
assert log_entry["metadata"]["status"] == "failure"
assert "user_api_key" in log_entry["metadata"]
assert "error_information" in log_entry["metadata"]
# Verify error information
error_info = log_entry["metadata"]["error_information"]
assert "traceback" in error_info
assert error_info["error_code"] == "400"
assert error_info["error_class"] == "BadRequestError"
assert "litellm.BadRequestError" in error_info["error_message"]
assert "non-existent-model" in error_info["error_message"]
# Verify request details
assert log_entry["cache_hit"] == "False"
assert log_entry["response"] == {}

View File

@@ -0,0 +1,312 @@
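"""
E2E tests for team-scoped models on a running LiteLLM proxy: team model aliases,
associating models with a team via model_info.team_id, and model visibility in
/models vs /v2/model/info.
"""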
import asyncio
import json
import os
import uuid
import aiohttp
import pytest
from httpx import AsyncClient
from openai import AsyncOpenAI
TEST_MASTER_KEY = "sk-1234"
PROXY_BASE_URL = "http://0.0.0.0:4000"
@pytest.mark.asyncio
async def test_team_model_alias():
"""
Test model alias functionality with teams:
    1. Add a new model with model_name="gpt-4o-team1" and litellm_params.model="gpt-4o"
2. Create a new team
3. Update team with model_alias mapping
4. Generate key for team
5. Make request with aliased model name
"""
client = AsyncClient(base_url=PROXY_BASE_URL)
headers = {"Authorization": f"Bearer {TEST_MASTER_KEY}"}
# Add new model
model_response = await client.post(
"/model/new",
json={
"model_name": "gpt-4o-team1",
"litellm_params": {
"model": "gpt-4o",
"api_key": os.getenv("OPENAI_API_KEY"),
},
},
headers=headers,
)
assert model_response.status_code == 200
# Create new team
team_response = await client.post(
"/team/new",
json={
"models": ["gpt-4o-team1"],
},
headers=headers,
)
assert team_response.status_code == 200
team_data = team_response.json()
team_id = team_data["team_id"]
# Update team with model alias
update_response = await client.post(
"/team/update",
json={"team_id": team_id, "model_aliases": {"gpt-4o": "gpt-4o-team1"}},
headers=headers,
)
assert update_response.status_code == 200
# Generate key for team
key_response = await client.post(
"/key/generate", json={"team_id": team_id}, headers=headers
)
assert key_response.status_code == 200
key = key_response.json()["key"]
# Make request with model alias
openai_client = AsyncOpenAI(api_key=key, base_url=f"{PROXY_BASE_URL}/v1")
response = await openai_client.chat.completions.create(
model="gpt-4o",
messages=[{"role": "user", "content": f"Test message {uuid.uuid4()}"}],
)
assert response is not None, "Should get valid response when using model alias"
# Cleanup - delete the model
model_id = model_response.json()["model_info"]["id"]
delete_response = await client.post(
"/model/delete",
json={"id": model_id},
headers={"Authorization": f"Bearer {TEST_MASTER_KEY}"},
)
assert delete_response.status_code == 200
@pytest.mark.asyncio
async def test_team_model_association():
"""
Test that models created with a team_id are properly associated with the team:
1. Create a new team
2. Add a model with team_id in model_info
3. Verify the model appears in team info
"""
client = AsyncClient(base_url=PROXY_BASE_URL)
headers = {"Authorization": f"Bearer {TEST_MASTER_KEY}"}
# Create new team
team_response = await client.post(
"/team/new",
json={
"models": [], # Start with empty model list
},
headers=headers,
)
assert team_response.status_code == 200
team_data = team_response.json()
team_id = team_data["team_id"]
# Add new model with team_id
model_response = await client.post(
"/model/new",
json={
"model_name": "gpt-4-team-test",
"litellm_params": {
"model": "gpt-4",
"custom_llm_provider": "openai",
"api_key": "fake_key",
},
"model_info": {"team_id": team_id},
},
headers=headers,
)
assert model_response.status_code == 200
# Get team info and verify model association
team_info_response = await client.get(
f"/team/info",
headers=headers,
params={"team_id": team_id},
)
assert team_info_response.status_code == 200
team_info = team_info_response.json()["team_info"]
print("team_info", json.dumps(team_info, indent=4))
# Verify the model is in team_models
assert (
"gpt-4-team-test" in team_info["models"]
), "Model should be associated with team"
# Cleanup - delete the model
model_id = model_response.json()["model_info"]["id"]
delete_response = await client.post(
"/model/delete",
json={"id": model_id},
headers=headers,
)
assert delete_response.status_code == 200
@pytest.mark.asyncio
async def test_team_model_visibility_in_models_endpoint():
"""
Test that team-specific models are only visible to the correct team in /models endpoint:
1. Create two teams
2. Add a model associated with team1
3. Generate keys for both teams
4. Verify team1's key can see the model in /models
5. Verify team2's key cannot see the model in /models
"""
client = AsyncClient(base_url=PROXY_BASE_URL)
headers = {"Authorization": f"Bearer {TEST_MASTER_KEY}"}
# Create team1
team1_response = await client.post(
"/team/new",
json={"models": []},
headers=headers,
)
assert team1_response.status_code == 200
team1_id = team1_response.json()["team_id"]
# Create team2
team2_response = await client.post(
"/team/new",
json={"models": []},
headers=headers,
)
assert team2_response.status_code == 200
team2_id = team2_response.json()["team_id"]
# Add model associated with team1
model_response = await client.post(
"/model/new",
json={
"model_name": "gpt-4-team-test",
"litellm_params": {
"model": "gpt-4",
"custom_llm_provider": "openai",
"api_key": "fake_key",
},
"model_info": {"team_id": team1_id},
},
headers=headers,
)
assert model_response.status_code == 200
# Generate keys for both teams
team1_key = (
await client.post("/key/generate", json={"team_id": team1_id}, headers=headers)
).json()["key"]
team2_key = (
await client.post("/key/generate", json={"team_id": team2_id}, headers=headers)
).json()["key"]
# Check models visibility for team1's key
team1_models = await client.get(
"/models", headers={"Authorization": f"Bearer {team1_key}"}
)
assert team1_models.status_code == 200
print("team1_models", json.dumps(team1_models.json(), indent=4))
assert any(
model["id"] == "gpt-4-team-test" for model in team1_models.json()["data"]
), "Team1 should see their model"
# Check models visibility for team2's key
team2_models = await client.get(
"/models", headers={"Authorization": f"Bearer {team2_key}"}
)
assert team2_models.status_code == 200
print("team2_models", json.dumps(team2_models.json(), indent=4))
assert not any(
model["id"] == "gpt-4-team-test" for model in team2_models.json()["data"]
), "Team2 should not see team1's model"
# Cleanup
model_id = model_response.json()["model_info"]["id"]
await client.post("/model/delete", json={"id": model_id}, headers=headers)
@pytest.mark.asyncio
async def test_team_model_visibility_in_model_info_endpoint():
"""
Test that team-specific models are visible to all users in /v2/model/info endpoint:
Note: /v2/model/info is used by the Admin UI to display model info
1. Create a team
2. Add a model associated with the team
3. Generate a team key
4. Verify both team key and non-team key can see the model in /v2/model/info
"""
client = AsyncClient(base_url=PROXY_BASE_URL)
headers = {"Authorization": f"Bearer {TEST_MASTER_KEY}"}
# Create team
team_response = await client.post(
"/team/new",
json={"models": []},
headers=headers,
)
assert team_response.status_code == 200
team_id = team_response.json()["team_id"]
# Add model associated with team
model_response = await client.post(
"/model/new",
json={
"model_name": "gpt-4-team-test",
"litellm_params": {
"model": "gpt-4",
"custom_llm_provider": "openai",
"api_key": "fake_key",
},
"model_info": {"team_id": team_id},
},
headers=headers,
)
assert model_response.status_code == 200
# Generate team key
team_key = (
await client.post("/key/generate", json={"team_id": team_id}, headers=headers)
).json()["key"]
# Generate non-team key
non_team_key = (
await client.post("/key/generate", json={}, headers=headers)
).json()["key"]
# Check model info visibility with team key
team_model_info = await client.get(
"/v2/model/info",
headers={"Authorization": f"Bearer {team_key}"},
params={"model_name": "gpt-4-team-test"},
)
assert team_model_info.status_code == 200
team_model_info = team_model_info.json()
print("Team 1 model info", json.dumps(team_model_info, indent=4))
assert any(
model["model_info"].get("team_public_model_name") == "gpt-4-team-test"
for model in team_model_info["data"]
), "Team1 should see their model"
# Check model info visibility with non-team key
non_team_model_info = await client.get(
"/v2/model/info",
headers={"Authorization": f"Bearer {non_team_key}"},
params={"model_name": "gpt-4-team-test"},
)
assert non_team_model_info.status_code == 200
non_team_model_info = non_team_model_info.json()
print("Non-team model info", json.dumps(non_team_model_info, indent=4))
assert any(
model["model_info"].get("team_public_model_name") == "gpt-4-team-test"
for model in non_team_model_info["data"]
), "Non-team should see the model"
# Cleanup
model_id = model_response.json()["model_info"]["id"]
await client.post("/model/delete", json={"id": model_id}, headers=headers)