Added LiteLLM to the stack
@@ -0,0 +1,202 @@
"""
Test adding a pass-through AssemblyAI model + api key + api base to the db,
wait 20 seconds, then make a request.

Cases to cover

1. user points api base to <proxy-base>/assemblyai
2. user points api base to <proxy-base>/assemblyai/us
3. user points api base to <proxy-base>/assemblyai/eu
4. Bad API Key / credential - 401
"""

import time
import assemblyai as aai
import pytest
import httpx
import os
import json

TEST_MASTER_KEY = "sk-1234"
PROXY_BASE_URL = "http://0.0.0.0:4000"
US_BASE_URL = f"{PROXY_BASE_URL}/assemblyai"
EU_BASE_URL = f"{PROXY_BASE_URL}/eu.assemblyai"
ASSEMBLYAI_API_KEY_ENV_VAR = "TEST_SPECIAL_ASSEMBLYAI_API_KEY"


def _delete_all_assemblyai_models_from_db():
    """
    Delete all AssemblyAI models from the db.
    """
    print("Deleting all assemblyai models from the db.......")
    model_list_response = httpx.get(
        url=f"{PROXY_BASE_URL}/v2/model/info",
        headers={"Authorization": f"Bearer {TEST_MASTER_KEY}"},
    )
    response_data = model_list_response.json()
    print("model list response", json.dumps(response_data, indent=4, default=str))

    # Filter for only AssemblyAI models
    assemblyai_models = [
        model
        for model in response_data["data"]
        if model.get("litellm_params", {}).get("custom_llm_provider") == "assemblyai"
    ]

    for model in assemblyai_models:
        model_id = model["model_info"]["id"]
        httpx.post(
            url=f"{PROXY_BASE_URL}/model/delete",
            headers={"Authorization": f"Bearer {TEST_MASTER_KEY}"},
            json={"id": model_id},
        )
    print("Deleted all assemblyai models from the db")


@pytest.fixture(autouse=True)
def cleanup_assemblyai_models():
    """
    Fixture to clean up AssemblyAI models before and after each test.
    """
    # Clean up before test
    _delete_all_assemblyai_models_from_db()

    # Run the test
    yield

    # Clean up after test
    _delete_all_assemblyai_models_from_db()


def test_e2e_assemblyai_passthrough():
    """
    Test adding a pass-through AssemblyAI model + api key + api base to the db,
    then making a request through the US base URL.
    """
    add_assembly_ai_model_to_db(api_base="https://api.assemblyai.com")
    virtual_key = create_virtual_key()

    # make request
    make_assemblyai_basic_transcribe_request(
        virtual_key=virtual_key, assemblyai_base_url=US_BASE_URL
    )


def test_e2e_assemblyai_passthrough_eu():
    """
    Test adding a pass-through AssemblyAI model + api key + api base to the db,
    then making a request through the EU base URL.
    """
    add_assembly_ai_model_to_db(api_base="https://api.eu.assemblyai.com")
    virtual_key = create_virtual_key()

    # make request
    make_assemblyai_basic_transcribe_request(
        virtual_key=virtual_key, assemblyai_base_url=EU_BASE_URL
    )


def test_assemblyai_routes_with_bad_api_key():
    """
    Test AssemblyAI endpoints with an invalid API key to ensure proper error handling.
    """
    bad_api_key = "sk-12222"
    payload = {
        "audio_url": "https://assembly.ai/wildfires.mp3",
        "audio_end_at": 280,
        "audio_start_from": 10,
        "auto_chapters": True,
    }
    headers = {
        "Authorization": f"Bearer {bad_api_key}",
        "Content-Type": "application/json",
    }

    # Test EU endpoint
    eu_response = httpx.post(
        f"{PROXY_BASE_URL}/eu.assemblyai/v2/transcript", headers=headers, json=payload
    )
    assert (
        eu_response.status_code == 401
    ), f"Expected 401 unauthorized, got {eu_response.status_code}"

    # Test US endpoint
    us_response = httpx.post(
        f"{PROXY_BASE_URL}/assemblyai/v2/transcript", headers=headers, json=payload
    )
    assert (
        us_response.status_code == 401
    ), f"Expected 401 unauthorized, got {us_response.status_code}"


def create_virtual_key():
    """
    Create a virtual key.
    """
    response = httpx.post(
        url=f"{PROXY_BASE_URL}/key/generate",
        headers={"Authorization": f"Bearer {TEST_MASTER_KEY}"},
        json={},
    )
    print(response.json())
    return response.json()["key"]


def add_assembly_ai_model_to_db(
    api_base: str,
):
    """
    Add the AssemblyAI model to the db - makes an HTTP request to the /model/new endpoint on PROXY_BASE_URL.
    """
    print("assemblyai api key", os.getenv(ASSEMBLYAI_API_KEY_ENV_VAR))
    response = httpx.post(
        url=f"{PROXY_BASE_URL}/model/new",
        headers={"Authorization": f"Bearer {TEST_MASTER_KEY}"},
        json={
            "model_name": "assemblyai/*",
            "litellm_params": {
                "model": "assemblyai/*",
                "custom_llm_provider": "assemblyai",
                "api_key": os.getenv(ASSEMBLYAI_API_KEY_ENV_VAR),
                "api_base": api_base,
                "use_in_pass_through": True,
            },
            "model_info": {},
        },
    )
    print(response.json())


def make_assemblyai_basic_transcribe_request(
    virtual_key: str, assemblyai_base_url: str
):
    print("making basic transcribe request to assemblyai passthrough")

    # Replace with your API key
    aai.settings.api_key = f"Bearer {virtual_key}"
    aai.settings.base_url = assemblyai_base_url

    # URL of the file to transcribe
    FILE_URL = "https://assembly.ai/wildfires.mp3"

    # You can also transcribe a local file by passing in a file path
    # FILE_URL = './path/to/file.mp3'

    transcriber = aai.Transcriber()
    transcript = transcriber.transcribe(FILE_URL)
    print(transcript)
    print(transcript.id)
    if transcript.id:
        transcript.delete_by_id(transcript.id)
    else:
        pytest.fail("Failed to get transcript id")

    if transcript.status == aai.TranscriptStatus.error:
        print(transcript.error)
        pytest.fail(f"Failed to transcribe file error: {transcript.error}")
    else:
        print(transcript.text)
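

# A minimal sketch (not used by the tests above) of how the same pass-through
# route can be exercised without the AssemblyAI SDK, assuming a valid virtual
# key from create_virtual_key(); the endpoint and payload mirror the
# bad-api-key test above. The helper name is illustrative only.
def _make_raw_transcribe_request_sketch(virtual_key: str, assemblyai_base_url: str):
    response = httpx.post(
        f"{assemblyai_base_url}/v2/transcript",
        headers={
            "Authorization": f"Bearer {virtual_key}",
            "Content-Type": "application/json",
        },
        json={"audio_url": "https://assembly.ai/wildfires.mp3"},
    )
    print(response.status_code)
    # A 200 with a transcript id indicates the proxy forwarded the request
    # to AssemblyAI successfully.
    if response.status_code == 200:
        print("transcript id:", response.json().get("id"))
    return response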
@@ -0,0 +1,93 @@
"""
PROD TEST - DO NOT Delete this Test

e2e test for langfuse callback in DB
- Add langfuse callback to DB - with /config/update
- wait 20 seconds for the callback to be loaded into the instance
- Make a /chat/completions request to the proxy
- Check if the request is logged in Langfuse
"""

import pytest
import asyncio
import aiohttp
import os
from dotenv import load_dotenv
from openai import AsyncOpenAI
from openai.types.chat import ChatCompletion

load_dotenv()

# used for testing
LANGFUSE_BASE_URL = "https://exampleopenaiendpoint-production-c715.up.railway.app"


async def config_update(session, routing_strategy=None):
    url = "http://0.0.0.0:4000/config/update"
    headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
    print("routing_strategy: ", routing_strategy)
    data = {
        "litellm_settings": {"success_callback": ["langfuse"]},
        "environment_variables": {
            "LANGFUSE_PUBLIC_KEY": "any-public-key",
            "LANGFUSE_SECRET_KEY": "any-secret-key",
            "LANGFUSE_HOST": LANGFUSE_BASE_URL,
        },
    }

    async with session.post(url, headers=headers, json=data) as response:
        status = response.status
        response_text = await response.text()

        print(response_text)
        print("status: ", status)

        if status != 200:
            raise Exception(f"Request did not return a 200 status code: {status}")
        return await response.json()


async def check_langfuse_request(response_id: str):
    async with aiohttp.ClientSession() as session:
        url = f"{LANGFUSE_BASE_URL}/langfuse/trace/{response_id}"
        async with session.get(url) as response:
            response_json = await response.json()
            assert response.status == 200, f"Expected status 200, got {response.status}"
            assert (
                response_json["exists"] is True
            ), f"Request {response_id} not found in Langfuse traces"
            assert response_json["request_id"] == response_id, "Request ID mismatch"


async def make_chat_completions_request() -> ChatCompletion:
    client = AsyncOpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")
    response = await client.chat.completions.create(
        model="fake-openai-endpoint",
        messages=[{"role": "user", "content": "Hello, world!"}],
    )
    print(response)
    return response


@pytest.mark.asyncio
async def test_e2e_langfuse_callbacks_in_db():
    async with aiohttp.ClientSession() as session:
        # add langfuse callback to DB
        await config_update(session)

        # wait 20 seconds for the callback to be loaded into the instance
        await asyncio.sleep(20)

        # make a /chat/completions request to the proxy
        response = await make_chat_completions_request()
        print(response)
        response_id = response.id
        print("response_id: ", response_id)

        await asyncio.sleep(11)
        # check if the request is logged in Langfuse
        await check_langfuse_request(response_id)
@@ -0,0 +1,325 @@
from datetime import datetime
from typing import List, Optional
import pytest
import uuid
import os
import asyncio
from unittest import mock
from fastapi.testclient import TestClient
from fastapi import FastAPI

from starlette import status

from litellm.constants import LITELLM_PROXY_ADMIN_NAME
from litellm.proxy._types import (
    MCPSpecVersion,
    MCPSpecVersionType,
    MCPTransportType,
    MCPTransport,
    NewMCPServerRequest,
    LiteLLM_MCPServerTable,
    LitellmUserRoles,
    UserAPIKeyAuth,
)
from litellm.types.mcp import MCPAuth
from litellm.proxy.management_endpoints.mcp_management_endpoints import (
    does_mcp_server_exist,
)

TEST_MASTER_KEY = os.getenv("LITELLM_MASTER_KEY", "sk-1234")


def generate_mcpserver_record(
    url: Optional[str] = None,
    transport: Optional[MCPTransportType] = None,
    spec_version: Optional[MCPSpecVersionType] = None,
) -> LiteLLM_MCPServerTable:
    """
    Generate a mock record for testing.
    """
    now = datetime.now()

    return LiteLLM_MCPServerTable(
        server_id=str(uuid.uuid4()),
        alias="Test Server",
        url=url or "http://localhost.com:8080/mcp",
        transport=transport or MCPTransport.sse,
        spec_version=spec_version or MCPSpecVersion.mar_2025,
        created_at=now,
        updated_at=now,
    )


# Cheers SO
def is_valid_uuid(val):
    try:
        uuid.UUID(str(val))
        return True
    except ValueError:
        return False


def generate_mcpserver_create_request(
    server_id: Optional[str] = None,
    url: Optional[str] = None,
    transport: Optional[MCPTransportType] = None,
    spec_version: Optional[MCPSpecVersionType] = None,
) -> NewMCPServerRequest:
    """
    Generate a mock create request for testing.
    """
    return NewMCPServerRequest(
        server_id=server_id,
        alias="Test Server",
        url=url or "http://localhost.com:8080/mcp",
        transport=transport or MCPTransport.sse,
        spec_version=spec_version or MCPSpecVersion.mar_2025,
    )


def assert_mcp_server_record_same(
    mcp_server: NewMCPServerRequest, resp: LiteLLM_MCPServerTable
):
    """
    Assert that the mcp server record is created correctly.
    """
    if mcp_server.server_id is not None:
        assert resp.server_id == mcp_server.server_id
    else:
        assert is_valid_uuid(resp.server_id)
    assert resp.alias == mcp_server.alias
    assert resp.url == mcp_server.url
    assert resp.description == mcp_server.description
    assert resp.transport == mcp_server.transport
    assert resp.spec_version == mcp_server.spec_version
    assert resp.auth_type == mcp_server.auth_type
    assert resp.created_at is not None
    assert resp.updated_at is not None
    assert resp.created_by == LITELLM_PROXY_ADMIN_NAME
    assert resp.updated_by == LITELLM_PROXY_ADMIN_NAME


def test_does_mcp_server_exist():
    """
    Unit Test if the MCP server exists in the list.
    """
    mcp_server_records: List[LiteLLM_MCPServerTable] = [
        generate_mcpserver_record(),
        generate_mcpserver_record(),
    ]
    # test all records are found
    for record in mcp_server_records:
        assert does_mcp_server_exist(mcp_server_records, record.server_id)

    # test record not found
    not_found_record = str(uuid.uuid4())
    assert not does_mcp_server_exist(mcp_server_records, not_found_record)
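

# For reference, a minimal sketch of the membership check the test above
# assumes: does_mcp_server_exist returns True when any record in the list has
# a matching server_id. This mirrors the asserted behavior, not necessarily
# the production implementation in litellm.
def _does_mcp_server_exist_sketch(
    mcp_server_records: List[LiteLLM_MCPServerTable], server_id: str
) -> bool:
    return any(record.server_id == server_id for record in mcp_server_records)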


@pytest.mark.asyncio
async def test_create_mcp_server_direct():
    """
    Direct test of the MCP server creation logic without HTTP calls.
    """
    # Mock the database functions directly
    with mock.patch("litellm.proxy.management_endpoints.mcp_management_endpoints.MCP_AVAILABLE", True), \
         mock.patch("litellm.proxy.management_endpoints.mcp_management_endpoints.get_prisma_client_or_throw") as mock_get_prisma, \
         mock.patch("litellm.proxy.management_endpoints.mcp_management_endpoints.create_mcp_server", new_callable=mock.AsyncMock) as mock_create, \
         mock.patch("litellm.proxy.management_endpoints.mcp_management_endpoints.get_mcp_server", new_callable=mock.AsyncMock) as mock_get_server, \
         mock.patch("litellm.proxy.management_endpoints.mcp_management_endpoints.global_mcp_server_manager") as mock_manager:

        # Import after mocking
        from litellm.proxy.management_endpoints.mcp_management_endpoints import add_mcp_server

        # Mock database client
        mock_prisma = mock.Mock()
        mock_get_prisma.return_value = mock_prisma

        # Mock server manager
        mock_manager.add_update_server = mock.Mock()
        mock_manager.reload_servers_from_database = mock.AsyncMock()

        # Set up test data
        server_id = str(uuid.uuid4())
        mcp_server_request = generate_mcpserver_create_request(server_id=server_id)

        # The function will normalize the alias by replacing spaces with underscores
        expected_alias = mcp_server_request.alias.replace(" ", "_") if mcp_server_request.alias else None

        expected_response = LiteLLM_MCPServerTable(
            server_id=server_id,
            alias=expected_alias,  # Use the normalized alias
            description=mcp_server_request.description,
            url=mcp_server_request.url,
            transport=mcp_server_request.transport,
            spec_version=mcp_server_request.spec_version,
            auth_type=mcp_server_request.auth_type,
            created_at=datetime.now(),
            updated_at=datetime.now(),
            created_by=LITELLM_PROXY_ADMIN_NAME,
            updated_by=LITELLM_PROXY_ADMIN_NAME,
            teams=[],
        )

        # Mock the database calls
        mock_get_server.return_value = None  # Server doesn't exist yet
        # Set up async mock for create_mcp_server using AsyncMock
        mock_create.return_value = expected_response

        # Create mock user auth
        user_auth = UserAPIKeyAuth(
            api_key=TEST_MASTER_KEY,
            user_id="test-user",
            user_role=LitellmUserRoles.PROXY_ADMIN,
        )

        # Call the function directly
        result = await add_mcp_server(
            payload=mcp_server_request,
            user_api_key_dict=user_auth,
        )

        # Verify the result
        assert result.server_id == server_id
        assert result.alias == expected_alias  # Check against normalized alias
        assert result.url == mcp_server_request.url
        assert result.transport == mcp_server_request.transport
        assert result.spec_version == mcp_server_request.spec_version

        # Verify mocks were called
        mock_get_server.assert_called_once_with(mock_prisma, server_id)
        mock_create.assert_called_once()
        mock_manager.add_update_server.assert_called_once_with(expected_response)


@pytest.mark.asyncio
async def test_create_duplicate_mcp_server():
    """
    Test that creating a duplicate MCP server fails appropriately.
    """
    # Mock the database functions directly
    with mock.patch("litellm.proxy.management_endpoints.mcp_management_endpoints.MCP_AVAILABLE", True), \
         mock.patch("litellm.proxy.management_endpoints.mcp_management_endpoints.get_prisma_client_or_throw") as mock_get_prisma, \
         mock.patch("litellm.proxy.management_endpoints.mcp_management_endpoints.get_mcp_server", new_callable=mock.AsyncMock) as mock_get_server:

        # Import after mocking
        from litellm.proxy.management_endpoints.mcp_management_endpoints import add_mcp_server
        from fastapi import HTTPException

        # Mock database client
        mock_prisma = mock.Mock()
        mock_get_prisma.return_value = mock_prisma

        # Set up test data
        server_id = str(uuid.uuid4())
        mcp_server_request = generate_mcpserver_create_request(server_id=server_id)

        existing_server = LiteLLM_MCPServerTable(
            server_id=server_id,
            alias="Existing Server",
            url="http://existing.com",
            transport=MCPTransport.sse,
            spec_version=MCPSpecVersion.mar_2025,
            created_at=datetime.now(),
            updated_at=datetime.now(),
            teams=[],
        )

        # Mock that server already exists
        mock_get_server.return_value = existing_server

        # Create mock user auth
        user_auth = UserAPIKeyAuth(
            api_key=TEST_MASTER_KEY,
            user_id="test-user",
            user_role=LitellmUserRoles.PROXY_ADMIN,
        )

        # Expect HTTPException to be raised
        with pytest.raises(HTTPException) as exc_info:
            await add_mcp_server(
                payload=mcp_server_request,
                user_api_key_dict=user_auth,
            )

        # Verify the exception details
        assert exc_info.value.status_code == 400
        assert "already exists" in str(exc_info.value.detail)


@pytest.mark.asyncio
async def test_create_mcp_server_auth_failure():
    """
    Test that non-admin users cannot create MCP servers.
    """
    # Mock the database functions directly
    with mock.patch("litellm.proxy.management_endpoints.mcp_management_endpoints.MCP_AVAILABLE", True), \
         mock.patch("litellm.proxy.management_endpoints.mcp_management_endpoints.get_prisma_client_or_throw") as mock_get_prisma:

        # Import after mocking
        from litellm.proxy.management_endpoints.mcp_management_endpoints import add_mcp_server
        from fastapi import HTTPException

        # Mock database client
        mock_prisma = mock.Mock()
        mock_get_prisma.return_value = mock_prisma

        # Set up test data
        server_id = str(uuid.uuid4())
        mcp_server_request = generate_mcpserver_create_request(server_id=server_id)

        # Create mock user auth without admin role
        user_auth = UserAPIKeyAuth(
            api_key=TEST_MASTER_KEY,
            user_id="test-user",
            user_role=LitellmUserRoles.INTERNAL_USER,  # Not an admin
        )

        # Expect HTTPException to be raised
        with pytest.raises(HTTPException) as exc_info:
            await add_mcp_server(
                payload=mcp_server_request,
                user_api_key_dict=user_auth,
            )

        # Verify the exception details
        assert exc_info.value.status_code == 403
        assert "permission" in str(exc_info.value.detail)


@pytest.mark.asyncio
async def test_create_mcp_server_invalid_alias():
    """
    Test that creating an MCP server with a '-' in the alias fails with the correct error.
    """
    with mock.patch("litellm.proxy.management_endpoints.mcp_management_endpoints.MCP_AVAILABLE", True), \
         mock.patch("litellm.proxy.management_endpoints.mcp_management_endpoints.get_prisma_client_or_throw") as mock_get_prisma, \
         mock.patch("litellm.proxy.management_endpoints.mcp_management_endpoints.get_mcp_server") as mock_get_server, \
         mock.patch("litellm.proxy.management_endpoints.mcp_management_endpoints.create_mcp_server") as mock_create:

        from litellm.proxy.management_endpoints.mcp_management_endpoints import add_mcp_server
        from fastapi import HTTPException

        mock_prisma = mock.Mock()
        mock_get_prisma.return_value = mock_prisma

        # Set up test data with invalid alias
        server_id = str(uuid.uuid4())
        mcp_server_request = generate_mcpserver_create_request(server_id=server_id)
        mcp_server_request.alias = "invalid-alias"  # This should trigger the validation error

        # Mock that server does not exist
        mock_get_server.return_value = None

        # Mock create_mcp_server to prevent 500 error (this should not be called due to validation)
        mock_create.return_value = None

        user_auth = UserAPIKeyAuth(
            api_key=TEST_MASTER_KEY,
            user_id="test-user",
            user_role=LitellmUserRoles.PROXY_ADMIN,
        )

        with pytest.raises(HTTPException) as exc_info:
            await add_mcp_server(
                payload=mcp_server_request,
                user_api_key_dict=user_auth,
            )

        assert exc_info.value.status_code == 400
        assert "Server name cannot contain '-'. Use an alternative character instead Found: invalid-alias" in str(exc_info.value.detail)


def test_validate_mcp_server_name_direct():
    """
    Test the validation function directly to ensure it works.
    """
    from litellm.proxy._experimental.mcp_server.utils import validate_mcp_server_name
    from fastapi import HTTPException

    # Test that valid names pass
    validate_mcp_server_name("valid_name")
    validate_mcp_server_name("valid name")

    # Test that invalid names with hyphens raise exceptions
    with pytest.raises(Exception) as exc_info:
        validate_mcp_server_name("invalid-name")
    assert "cannot contain" in str(exc_info.value)

    # Test that invalid names with hyphens raise HTTPException when requested
    with pytest.raises(HTTPException) as exc_info:
        validate_mcp_server_name("invalid-name", raise_http_exception=True)
    assert exc_info.value.status_code == 400
    assert "cannot contain" in str(exc_info.value.detail)
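

# For reference, a minimal sketch of the name validation behavior the tests
# above rely on: reject any server name containing '-', optionally surfacing
# the failure as an HTTP 400. It mirrors the asserted error message and status
# code, not necessarily the exact implementation in
# litellm.proxy._experimental.mcp_server.utils.
def _validate_mcp_server_name_sketch(name: str, raise_http_exception: bool = False) -> None:
    from fastapi import HTTPException

    if "-" in name:
        message = (
            f"Server name cannot contain '-'. Use an alternative character instead Found: {name}"
        )
        if raise_http_exception:
            raise HTTPException(status_code=400, detail=message)
        raise ValueError(message)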
@@ -0,0 +1,208 @@
import pytest
from openai import OpenAI, BadRequestError, AsyncOpenAI
import asyncio
import httpx


def generate_key_sync():
    url = "http://0.0.0.0:4000/key/generate"
    headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}

    with httpx.Client() as client:
        response = client.post(
            url,
            headers=headers,
            json={
                "models": [
                    "gpt-4",
                    "text-embedding-ada-002",
                    "dall-e-2",
                    "fake-openai-endpoint-2",
                    "mistral-embed",
                    "non-existent-model",
                ],
            },
        )
        response_text = response.text

        print(response_text)
        print()

        if response.status_code != 200:
            raise Exception(
                f"Request did not return a 200 status code: {response.status_code}"
            )

        response_data = response.json()
        return response_data["key"]


def test_chat_completion_bad_model():
    key = generate_key_sync()
    client = OpenAI(api_key=key, base_url="http://0.0.0.0:4000")

    with pytest.raises(BadRequestError) as excinfo:
        client.chat.completions.create(
            model="non-existent-model", messages=[{"role": "user", "content": "Hello!"}]
        )
    print(f"Chat completion error: {excinfo.value}")


def test_completion_bad_model():
    key = generate_key_sync()
    client = OpenAI(api_key=key, base_url="http://0.0.0.0:4000")

    with pytest.raises(BadRequestError) as excinfo:
        client.completions.create(model="non-existent-model", prompt="Hello!")
    print(f"Completion error: {excinfo.value}")


def test_embeddings_bad_model():
    key = generate_key_sync()
    client = OpenAI(api_key=key, base_url="http://0.0.0.0:4000")

    with pytest.raises(BadRequestError) as excinfo:
        client.embeddings.create(model="non-existent-model", input="Hello world")
    print(f"Embeddings error: {excinfo.value}")


def test_images_bad_model():
    key = generate_key_sync()
    client = OpenAI(api_key=key, base_url="http://0.0.0.0:4000")

    with pytest.raises(BadRequestError) as excinfo:
        client.images.generate(
            model="non-existent-model", prompt="A cute baby sea otter"
        )
    print(f"Images error: {excinfo.value}")


@pytest.mark.asyncio
async def test_async_chat_completion_bad_model():
    key = generate_key_sync()
    async_client = AsyncOpenAI(api_key=key, base_url="http://0.0.0.0:4000")

    with pytest.raises(BadRequestError) as excinfo:
        await async_client.chat.completions.create(
            model="non-existent-model", messages=[{"role": "user", "content": "Hello!"}]
        )
    print(f"Async chat completion error: {excinfo.value}")


@pytest.mark.parametrize(
    "curl_command",
    [
        'curl http://0.0.0.0:4000/v1/chat/completions -H \'Content-Type: application/json\' -H \'Authorization: Bearer sk-1234\' -d \'{"messages":[{"role":"user","content":"Hello!"}]}\'',
        "curl http://0.0.0.0:4000/v1/completions -H 'Content-Type: application/json' -H 'Authorization: Bearer sk-1234' -d '{\"prompt\":\"Hello!\"}'",
        "curl http://0.0.0.0:4000/v1/embeddings -H 'Content-Type: application/json' -H 'Authorization: Bearer sk-1234' -d '{\"input\":\"Hello world\"}'",
        "curl http://0.0.0.0:4000/v1/images/generations -H 'Content-Type: application/json' -H 'Authorization: Bearer sk-1234' -d '{\"prompt\":\"A cute baby sea otter\"}'",
    ],
    ids=["chat", "completions", "embeddings", "images"],
)
def test_missing_model_parameter_curl(curl_command):
    import subprocess
    import json

    # Run the curl command and capture the output
    key = generate_key_sync()
    curl_command = curl_command.replace("sk-1234", key)
    result = subprocess.run(curl_command, shell=True, capture_output=True, text=True)
    # Parse the JSON response
    response = json.loads(result.stdout)

    # Check that we got an error response
    assert "error" in response
    print("error in response", json.dumps(response, indent=4))

    assert "litellm.BadRequestError" in response["error"]["message"]


@pytest.mark.asyncio
async def test_chat_completion_bad_model_with_spend_logs():
    """
    Tests that Error Logs are created for failed requests
    """
    import json

    key = generate_key_sync()

    # Use httpx to make the request and capture headers
    url = "http://0.0.0.0:4000/v1/chat/completions"
    headers = {"Authorization": f"Bearer {key}", "Content-Type": "application/json"}
    payload = {
        "model": "non-existent-model",
        "messages": [{"role": "user", "content": "Hello!"}],
    }

    with httpx.Client() as client:
        response = client.post(url, headers=headers, json=payload)

    # Extract the litellm call ID from headers
    litellm_call_id = response.headers.get("x-litellm-call-id")
    print(f"Status code: {response.status_code}")
    print(f"Headers: {dict(response.headers)}")
    print(f"LiteLLM Call ID: {litellm_call_id}")

    # Parse the JSON response body
    try:
        response_body = response.json()
        print(f"Error response: {json.dumps(response_body, indent=4)}")
    except json.JSONDecodeError:
        print(f"Could not parse response body as JSON: {response.text}")

    assert (
        litellm_call_id is not None
    ), "Failed to get LiteLLM Call ID from response headers"
    print("waiting for flushing error log to db....")
    await asyncio.sleep(15)

    # Now query the spend logs
    url = "http://0.0.0.0:4000/spend/logs?request_id=" + litellm_call_id
    headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}

    with httpx.Client() as client:
        response = client.get(
            url,
            headers=headers,
        )

    assert (
        response.status_code == 200
    ), f"Failed to get spend logs: {response.status_code}"

    spend_logs = response.json()

    # Print the spend logs payload
    print(f"Spend logs response: {json.dumps(spend_logs, indent=4)}")

    # Verify we have logs for the failed request
    assert len(spend_logs) > 0, "No spend logs found"

    # Check if the error is recorded in the logs
    log_entry = spend_logs[0]  # Should be the specific log for our litellm_call_id

    # Verify the structure of the log entry
    assert log_entry["request_id"] == litellm_call_id
    assert log_entry["model"] == "non-existent-model"
    assert log_entry["model_group"] == "non-existent-model"
    assert log_entry["spend"] == 0.0
    assert log_entry["total_tokens"] == 0
    assert log_entry["prompt_tokens"] == 0
    assert log_entry["completion_tokens"] == 0

    # Verify metadata fields
    assert log_entry["metadata"]["status"] == "failure"
    assert "user_api_key" in log_entry["metadata"]
    assert "error_information" in log_entry["metadata"]

    # Verify error information
    error_info = log_entry["metadata"]["error_information"]
    assert "traceback" in error_info
    assert error_info["error_code"] == "400"
    assert error_info["error_class"] == "BadRequestError"
    assert "litellm.BadRequestError" in error_info["error_message"]
    assert "non-existent-model" in error_info["error_message"]

    # Verify request details
    assert log_entry["cache_hit"] == "False"
    assert log_entry["response"] == {}
@@ -0,0 +1,312 @@
import pytest
import asyncio
import aiohttp
import json
from openai import AsyncOpenAI
import uuid
from httpx import AsyncClient
import os

TEST_MASTER_KEY = "sk-1234"
PROXY_BASE_URL = "http://0.0.0.0:4000"


@pytest.mark.asyncio
async def test_team_model_alias():
    """
    Test model alias functionality with teams:
    1. Add a new model with model_name="gpt-4o-team1" and litellm_params.model="gpt-4o"
    2. Create a new team
    3. Update team with model_alias mapping
    4. Generate key for team
    5. Make request with aliased model name
    """
    client = AsyncClient(base_url=PROXY_BASE_URL)
    headers = {"Authorization": f"Bearer {TEST_MASTER_KEY}"}

    # Add new model
    model_response = await client.post(
        "/model/new",
        json={
            "model_name": "gpt-4o-team1",
            "litellm_params": {
                "model": "gpt-4o",
                "api_key": os.getenv("OPENAI_API_KEY"),
            },
        },
        headers=headers,
    )
    assert model_response.status_code == 200

    # Create new team
    team_response = await client.post(
        "/team/new",
        json={
            "models": ["gpt-4o-team1"],
        },
        headers=headers,
    )
    assert team_response.status_code == 200
    team_data = team_response.json()
    team_id = team_data["team_id"]

    # Update team with model alias
    update_response = await client.post(
        "/team/update",
        json={"team_id": team_id, "model_aliases": {"gpt-4o": "gpt-4o-team1"}},
        headers=headers,
    )
    assert update_response.status_code == 200

    # Generate key for team
    key_response = await client.post(
        "/key/generate", json={"team_id": team_id}, headers=headers
    )
    assert key_response.status_code == 200
    key = key_response.json()["key"]

    # Make request with model alias
    openai_client = AsyncOpenAI(api_key=key, base_url=f"{PROXY_BASE_URL}/v1")

    response = await openai_client.chat.completions.create(
        model="gpt-4o",
        messages=[{"role": "user", "content": f"Test message {uuid.uuid4()}"}],
    )

    assert response is not None, "Should get valid response when using model alias"

    # Cleanup - delete the model
    model_id = model_response.json()["model_info"]["id"]
    delete_response = await client.post(
        "/model/delete",
        json={"id": model_id},
        headers={"Authorization": f"Bearer {TEST_MASTER_KEY}"},
    )
    assert delete_response.status_code == 200
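

# A hedged refactor sketch: the team tests below repeat the same
# "create team -> generate team key" setup against /team/new and /key/generate.
# A shared helper like this (illustrative only, not used by the tests) could
# collapse that duplication.
async def _create_team_and_key_sketch(client: AsyncClient, headers: dict, models: list):
    team_response = await client.post("/team/new", json={"models": models}, headers=headers)
    assert team_response.status_code == 200
    team_id = team_response.json()["team_id"]

    key_response = await client.post(
        "/key/generate", json={"team_id": team_id}, headers=headers
    )
    assert key_response.status_code == 200
    return team_id, key_response.json()["key"]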


@pytest.mark.asyncio
async def test_team_model_association():
    """
    Test that models created with a team_id are properly associated with the team:
    1. Create a new team
    2. Add a model with team_id in model_info
    3. Verify the model appears in team info
    """
    client = AsyncClient(base_url=PROXY_BASE_URL)
    headers = {"Authorization": f"Bearer {TEST_MASTER_KEY}"}

    # Create new team
    team_response = await client.post(
        "/team/new",
        json={
            "models": [],  # Start with empty model list
        },
        headers=headers,
    )
    assert team_response.status_code == 200
    team_data = team_response.json()
    team_id = team_data["team_id"]

    # Add new model with team_id
    model_response = await client.post(
        "/model/new",
        json={
            "model_name": "gpt-4-team-test",
            "litellm_params": {
                "model": "gpt-4",
                "custom_llm_provider": "openai",
                "api_key": "fake_key",
            },
            "model_info": {"team_id": team_id},
        },
        headers=headers,
    )
    assert model_response.status_code == 200

    # Get team info and verify model association
    team_info_response = await client.get(
        "/team/info",
        headers=headers,
        params={"team_id": team_id},
    )
    assert team_info_response.status_code == 200
    team_info = team_info_response.json()["team_info"]

    print("team_info", json.dumps(team_info, indent=4))

    # Verify the model is in team_models
    assert (
        "gpt-4-team-test" in team_info["models"]
    ), "Model should be associated with team"

    # Cleanup - delete the model
    model_id = model_response.json()["model_info"]["id"]
    delete_response = await client.post(
        "/model/delete",
        json={"id": model_id},
        headers=headers,
    )
    assert delete_response.status_code == 200


@pytest.mark.asyncio
async def test_team_model_visibility_in_models_endpoint():
    """
    Test that team-specific models are only visible to the correct team in /models endpoint:
    1. Create two teams
    2. Add a model associated with team1
    3. Generate keys for both teams
    4. Verify team1's key can see the model in /models
    5. Verify team2's key cannot see the model in /models
    """
    client = AsyncClient(base_url=PROXY_BASE_URL)
    headers = {"Authorization": f"Bearer {TEST_MASTER_KEY}"}

    # Create team1
    team1_response = await client.post(
        "/team/new",
        json={"models": []},
        headers=headers,
    )
    assert team1_response.status_code == 200
    team1_id = team1_response.json()["team_id"]

    # Create team2
    team2_response = await client.post(
        "/team/new",
        json={"models": []},
        headers=headers,
    )
    assert team2_response.status_code == 200
    team2_id = team2_response.json()["team_id"]

    # Add model associated with team1
    model_response = await client.post(
        "/model/new",
        json={
            "model_name": "gpt-4-team-test",
            "litellm_params": {
                "model": "gpt-4",
                "custom_llm_provider": "openai",
                "api_key": "fake_key",
            },
            "model_info": {"team_id": team1_id},
        },
        headers=headers,
    )
    assert model_response.status_code == 200

    # Generate keys for both teams
    team1_key = (
        await client.post("/key/generate", json={"team_id": team1_id}, headers=headers)
    ).json()["key"]
    team2_key = (
        await client.post("/key/generate", json={"team_id": team2_id}, headers=headers)
    ).json()["key"]

    # Check models visibility for team1's key
    team1_models = await client.get(
        "/models", headers={"Authorization": f"Bearer {team1_key}"}
    )
    assert team1_models.status_code == 200
    print("team1_models", json.dumps(team1_models.json(), indent=4))
    assert any(
        model["id"] == "gpt-4-team-test" for model in team1_models.json()["data"]
    ), "Team1 should see their model"

    # Check models visibility for team2's key
    team2_models = await client.get(
        "/models", headers={"Authorization": f"Bearer {team2_key}"}
    )
    assert team2_models.status_code == 200
    print("team2_models", json.dumps(team2_models.json(), indent=4))
    assert not any(
        model["id"] == "gpt-4-team-test" for model in team2_models.json()["data"]
    ), "Team2 should not see team1's model"

    # Cleanup
    model_id = model_response.json()["model_info"]["id"]
    await client.post("/model/delete", json={"id": model_id}, headers=headers)


@pytest.mark.asyncio
async def test_team_model_visibility_in_model_info_endpoint():
    """
    Test that team-specific models are visible to all users in /v2/model/info endpoint:
    Note: /v2/model/info is used by the Admin UI to display model info
    1. Create a team
    2. Add a model associated with the team
    3. Generate a team key
    4. Verify both team key and non-team key can see the model in /v2/model/info
    """
    client = AsyncClient(base_url=PROXY_BASE_URL)
    headers = {"Authorization": f"Bearer {TEST_MASTER_KEY}"}

    # Create team
    team_response = await client.post(
        "/team/new",
        json={"models": []},
        headers=headers,
    )
    assert team_response.status_code == 200
    team_id = team_response.json()["team_id"]

    # Add model associated with team
    model_response = await client.post(
        "/model/new",
        json={
            "model_name": "gpt-4-team-test",
            "litellm_params": {
                "model": "gpt-4",
                "custom_llm_provider": "openai",
                "api_key": "fake_key",
            },
            "model_info": {"team_id": team_id},
        },
        headers=headers,
    )
    assert model_response.status_code == 200

    # Generate team key
    team_key = (
        await client.post("/key/generate", json={"team_id": team_id}, headers=headers)
    ).json()["key"]

    # Generate non-team key
    non_team_key = (
        await client.post("/key/generate", json={}, headers=headers)
    ).json()["key"]

    # Check model info visibility with team key
    team_model_info = await client.get(
        "/v2/model/info",
        headers={"Authorization": f"Bearer {team_key}"},
        params={"model_name": "gpt-4-team-test"},
    )
    assert team_model_info.status_code == 200
    team_model_info = team_model_info.json()
    print("Team 1 model info", json.dumps(team_model_info, indent=4))
    assert any(
        model["model_info"].get("team_public_model_name") == "gpt-4-team-test"
        for model in team_model_info["data"]
    ), "Team1 should see their model"

    # Check model info visibility with non-team key
    non_team_model_info = await client.get(
        "/v2/model/info",
        headers={"Authorization": f"Bearer {non_team_key}"},
        params={"model_name": "gpt-4-team-test"},
    )
    assert non_team_model_info.status_code == 200
    non_team_model_info = non_team_model_info.json()
    print("Non-team model info", json.dumps(non_team_model_info, indent=4))
    assert any(
        model["model_info"].get("team_public_model_name") == "gpt-4-team-test"
        for model in non_team_model_info["data"]
    ), "Non-team should see the model"

    # Cleanup
    model_id = model_response.json()["model_info"]["id"]
    await client.post("/model/delete", json={"id": model_id}, headers=headers)