Added LiteLLM to the stack

This commit is contained in:
2025-08-18 09:40:50 +00:00
parent 0648c1968c
commit d220b04e32
2682 changed files with 533609 additions and 1 deletions

View File

@@ -0,0 +1,91 @@
"""
Simple unit tests for CustomOpenAPISpec class.
Tests basic functionality of OpenAPI schema generation.
"""
from unittest.mock import Mock, patch
import pytest
from litellm.proxy.common_utils.custom_openapi_spec import CustomOpenAPISpec
class TestCustomOpenAPISpec:
    """Unit tests for the CustomOpenAPISpec schema-injection helpers.

    Each helper is expected to delegate to ``add_request_schema`` with the
    model class, schema name, path list and operation label specific to its
    endpoint family, and to return the (possibly updated) schema unchanged.
    """

    @pytest.fixture
    def base_openapi_schema(self):
        """Minimal OpenAPI 3.0 document covering the paths the helpers touch."""
        return {
            "openapi": "3.0.0",
            "info": {"title": "Test API", "version": "1.0.0"},
            "paths": {
                "/v1/chat/completions": {"post": {"summary": "Chat completions"}},
                "/v1/embeddings": {"post": {"summary": "Embeddings"}},
                "/v1/responses": {"post": {"summary": "Responses API"}},
            },
        }

    @patch('litellm.proxy.common_utils.custom_openapi_spec.CustomOpenAPISpec.add_request_schema')
    def test_add_chat_completion_request_schema(self, patched_add, base_openapi_schema):
        """Chat helper forwards chat-specific arguments to add_request_schema."""
        patched_add.return_value = base_openapi_schema
        with patch('litellm.proxy._types.ProxyChatCompletionRequest') as patched_model:
            returned = CustomOpenAPISpec.add_chat_completion_request_schema(base_openapi_schema)
        patched_add.assert_called_once_with(
            openapi_schema=base_openapi_schema,
            model_class=patched_model,
            schema_name="ProxyChatCompletionRequest",
            paths=CustomOpenAPISpec.CHAT_COMPLETION_PATHS,
            operation_name="chat completion",
        )
        assert returned == base_openapi_schema

    @patch('litellm.proxy.common_utils.custom_openapi_spec.CustomOpenAPISpec.add_request_schema')
    def test_add_embedding_request_schema(self, patched_add, base_openapi_schema):
        """Embedding helper forwards embedding-specific arguments."""
        patched_add.return_value = base_openapi_schema
        with patch('litellm.types.embedding.EmbeddingRequest') as patched_model:
            returned = CustomOpenAPISpec.add_embedding_request_schema(base_openapi_schema)
        patched_add.assert_called_once_with(
            openapi_schema=base_openapi_schema,
            model_class=patched_model,
            schema_name="EmbeddingRequest",
            paths=CustomOpenAPISpec.EMBEDDING_PATHS,
            operation_name="embedding",
        )
        assert returned == base_openapi_schema

    @patch('litellm.proxy.common_utils.custom_openapi_spec.CustomOpenAPISpec.add_request_schema')
    def test_add_responses_api_request_schema(self, patched_add, base_openapi_schema):
        """Responses-API helper forwards its endpoint-specific arguments."""
        patched_add.return_value = base_openapi_schema
        with patch('litellm.types.llms.openai.ResponsesAPIRequestParams') as patched_model:
            returned = CustomOpenAPISpec.add_responses_api_request_schema(base_openapi_schema)
        patched_add.assert_called_once_with(
            openapi_schema=base_openapi_schema,
            model_class=patched_model,
            schema_name="ResponsesAPIRequestParams",
            paths=CustomOpenAPISpec.RESPONSES_API_PATHS,
            operation_name="responses API",
        )
        assert returned == base_openapi_schema

View File

@@ -0,0 +1,222 @@
"""
Unit tests for GetRoutes utility class.
"""
from typing import Any, Dict, List
from unittest.mock import MagicMock, Mock
import pytest
from litellm.proxy.common_utils.get_routes import GetRoutes
class TestGetRoutes:
    """Unit tests for the GetRoutes route-discovery helpers.

    Covers plain API routes, Starlette ``Mount`` objects nested inside a
    mounted sub-app, a mix of both, and the StaticFiles regression where the
    mounted app object has no ``__name__`` attribute.
    """

    def test_get_app_routes_regular_route(self):
        """Test getting routes for a regular route with endpoint."""
        # Mock a regular route
        mock_route = Mock()
        mock_route.path = "/test/endpoint"
        mock_route.methods = ["GET", "POST"]
        mock_route.name = "test_endpoint"
        mock_route.endpoint = Mock()
        # Mock endpoint function
        mock_endpoint = Mock()
        mock_endpoint.__name__ = "test_function"
        result = GetRoutes.get_app_routes(mock_route, mock_endpoint)
        # One entry, reporting the endpoint by its function __name__
        assert len(result) == 1
        assert result[0]["path"] == "/test/endpoint"
        assert result[0]["methods"] == ["GET", "POST"]
        assert result[0]["name"] == "test_endpoint"
        assert result[0]["endpoint"] == "test_function"

    def test_get_routes_for_mounted_app_regular_routes(self):
        """Test getting routes for mounted app with regular API routes."""
        # Mock the main mount route
        mock_mount_route = Mock()
        mock_mount_route.path = "/mcp"
        # Mock sub-app with regular routes
        mock_sub_app = Mock()
        mock_sub_app.routes = []
        # Create a regular API route
        mock_api_route = Mock()
        mock_api_route.path = "/enabled"
        mock_api_route.methods = ["GET"]
        mock_api_route.name = "get_mcp_server_enabled"
        # Mock endpoint function
        mock_endpoint = Mock()
        mock_endpoint.__name__ = "get_mcp_server_enabled"
        mock_api_route.endpoint = mock_endpoint
        mock_api_route.app = None  # Regular route doesn't have app
        mock_sub_app.routes.append(mock_api_route)
        mock_mount_route.app = mock_sub_app
        result = GetRoutes.get_routes_for_mounted_app(mock_mount_route)
        # Mount prefix "/mcp" is joined with the sub-route path "/enabled"
        assert len(result) == 1
        assert result[0]["path"] == "/mcp/enabled"
        assert result[0]["methods"] == ["GET"]
        assert result[0]["name"] == "get_mcp_server_enabled"
        assert result[0]["endpoint"] == "get_mcp_server_enabled"
        assert result[0]["mounted_app"] is True

    def test_get_routes_for_mounted_app_mount_objects(self):
        """Test getting routes for mounted app with Mount objects (the main fix)."""
        # Mock the main mount route
        mock_mount_route = Mock()
        mock_mount_route.path = "/mcp"
        # Mock sub-app
        mock_sub_app = Mock()
        mock_sub_app.routes = []
        # Create Mount object for base MCP route (path='')
        # spec=[...] restricts the mock to these attributes only, so the code
        # under test cannot rely on auto-created Mock attributes.
        mock_mount_base = Mock(spec=['path', 'name', 'endpoint', 'app'])
        mock_mount_base.path = ""
        mock_mount_base.name = ""
        mock_mount_base.endpoint = None  # Mount objects don't have endpoint
        # Mock app function
        mock_app_function = Mock()
        mock_app_function.__name__ = "handle_streamable_http_mcp"
        mock_mount_base.app = mock_app_function
        # Create Mount object for SSE route (path='/sse')
        mock_mount_sse = Mock(spec=['path', 'name', 'endpoint', 'app'])
        mock_mount_sse.path = "/sse"
        mock_mount_sse.name = ""
        mock_mount_sse.endpoint = None  # Mount objects don't have endpoint
        # Mock app function for SSE
        mock_sse_function = Mock()
        mock_sse_function.__name__ = "handle_sse_mcp"
        mock_mount_sse.app = mock_sse_function
        mock_sub_app.routes.extend([mock_mount_base, mock_mount_sse])
        mock_mount_route.app = mock_sub_app
        result = GetRoutes.get_routes_for_mounted_app(mock_mount_route)
        # Should capture both /mcp and /mcp/sse routes
        assert len(result) == 2
        # Check base MCP route
        base_route = next(r for r in result if r["path"] == "/mcp")
        assert base_route["methods"] == ["GET", "POST"]  # Default methods
        assert base_route["endpoint"] == "handle_streamable_http_mcp"
        assert base_route["mounted_app"] is True
        # Check SSE route
        sse_route = next(r for r in result if r["path"] == "/mcp/sse")
        assert sse_route["methods"] == ["GET", "POST"]  # Default methods
        assert sse_route["endpoint"] == "handle_sse_mcp"
        assert sse_route["mounted_app"] is True

    def test_get_routes_for_mounted_app_mixed_routes(self):
        """Test getting routes for mounted app with both regular routes and Mount objects."""
        # Mock the main mount route
        mock_mount_route = Mock()
        mock_mount_route.path = "/mcp"
        # Mock sub-app
        mock_sub_app = Mock()
        mock_sub_app.routes = []
        # Create a regular API route
        mock_api_route = Mock()
        mock_api_route.path = "/enabled"
        mock_api_route.methods = ["GET"]
        mock_api_route.name = "get_mcp_server_enabled"
        mock_endpoint = Mock()
        mock_endpoint.__name__ = "get_mcp_server_enabled"
        mock_api_route.endpoint = mock_endpoint
        mock_api_route.app = None
        # Create Mount object
        mock_mount_base = Mock(spec=['path', 'name', 'endpoint', 'app'])
        mock_mount_base.path = ""
        mock_mount_base.name = ""
        mock_mount_base.endpoint = None
        mock_app_function = Mock()
        mock_app_function.__name__ = "handle_streamable_http_mcp"
        mock_mount_base.app = mock_app_function
        mock_sub_app.routes.extend([mock_api_route, mock_mount_base])
        mock_mount_route.app = mock_sub_app
        result = GetRoutes.get_routes_for_mounted_app(mock_mount_route)
        # Should capture both the API route and the Mount object
        assert len(result) == 2
        # Check API route
        api_route = next(r for r in result if r["path"] == "/mcp/enabled")
        assert api_route["methods"] == ["GET"]
        assert api_route["endpoint"] == "get_mcp_server_enabled"
        # Check Mount object route
        mount_route = next(r for r in result if r["path"] == "/mcp")
        assert mount_route["endpoint"] == "handle_streamable_http_mcp"
        assert mount_route["mounted_app"] is True

    def test_get_routes_for_mounted_app_with_static_files(self):
        """
        Test getting routes for mounted app with StaticFiles object (reproduces AttributeError bug).
        This test reproduces the exact stacktrace scenario:
        AttributeError: 'StaticFiles' object has no attribute '__name__'. Did you mean: '__ne__'?
        The original bug occurred when the code tried to access endpoint_func.__name__
        directly on a StaticFiles object. The fix uses _safe_get_endpoint_name() which
        gracefully handles objects without __name__ by falling back to class name.
        """
        # Mock the main mount route (e.g., /ui)
        mock_mount_route = Mock()
        mock_mount_route.path = "/ui"
        # Mock sub-app with routes
        mock_sub_app = Mock()
        mock_sub_app.routes = []
        # Create a mock StaticFiles route (this is the problematic case)
        mock_static_route = Mock(spec=['path', 'name', 'endpoint', 'app'])
        mock_static_route.path = ""
        mock_static_route.name = "ui"
        mock_static_route.endpoint = None
        # Mock StaticFiles object - this is the key part that caused the AttributeError
        # Real StaticFiles objects don't have __name__ attribute
        # Create a mock that simulates StaticFiles behavior (no __name__ attribute)
        class StaticFiles:
            """Mock class that simulates real StaticFiles without __name__ attribute"""
            pass
        mock_static_files = StaticFiles()
        # Verify no __name__ attribute exists on the instance (reproduces bug condition)
        assert not hasattr(mock_static_files, '__name__')
        mock_static_route.app = mock_static_files
        mock_sub_app.routes.append(mock_static_route)
        mock_mount_route.app = mock_sub_app
        # This should NOT raise AttributeError thanks to _safe_get_endpoint_name
        # In the old code, this would fail with: 'StaticFiles' object has no attribute '__name__'
        result = GetRoutes.get_routes_for_mounted_app(mock_mount_route)
        # Should handle StaticFiles gracefully without throwing AttributeError
        assert len(result) == 1
        assert result[0]["path"] == "/ui"
        assert result[0]["methods"] == ["GET", "POST"]  # Default methods
        assert result[0]["name"] == "ui"
        # Should fall back to class name since instance doesn't have __name__ attribute
        assert result[0]["endpoint"] == "StaticFiles"  # Falls back to class name
        assert result[0]["mounted_app"] is True

View File

@@ -0,0 +1,287 @@
import json
import os
import sys
from unittest.mock import AsyncMock, MagicMock, patch
import orjson
import pytest
from fastapi import Request
from fastapi.testclient import TestClient
sys.path.insert(
0, os.path.abspath("../../../..")
) # Adds the parent directory to the system path
import litellm
from litellm.proxy.common_utils.http_parsing_utils import (
_read_request_body,
_safe_get_request_parsed_body,
_safe_set_request_parsed_body,
get_form_data,
)
from litellm.proxy._types import ProxyException
@pytest.mark.asyncio
async def test_request_body_caching():
    """
    The first call to _read_request_body must parse the payload and cache it
    in request.scope; a second call must be served from that cache without
    reading request.body() again.
    """
    payload = {"key": "value"}
    request = MagicMock()
    # body() must be awaitable, hence AsyncMock
    request.body = AsyncMock(return_value=orjson.dumps(payload))
    request.headers = {"content-type": "application/json"}
    request.scope = {}

    # First read: parses and caches
    first = await _read_request_body(request)
    assert first == payload
    assert "parsed_body" in request.scope
    assert request.scope["parsed_body"] == (("key",), {"key": "value"})
    request.body.assert_called_once()

    # Second read: must hit the cache, not the body() coroutine
    request.body.reset_mock()
    second = await _read_request_body(request)
    assert second == {"key": "value"}
    request.body.assert_not_called()
@pytest.mark.asyncio
async def test_form_data_parsing():
    """
    Test that form data is correctly parsed from the request.

    A form-encoded request must be read via request.form(); request.body()
    must not be consulted.
    """
    # Create a mock request with form data
    mock_request = MagicMock()
    test_data = {"name": "test_user", "message": "hello world"}
    # Mock the form method to return the test data as an awaitable
    mock_request.form = AsyncMock(return_value=test_data)
    mock_request.headers = {"content-type": "application/x-www-form-urlencoded"}
    mock_request.scope = {}
    # Parse the form data
    result = await _read_request_body(mock_request)
    # Verify the form data was correctly parsed
    assert result == test_data
    assert "parsed_body" in mock_request.scope
    assert mock_request.scope["parsed_body"] == (
        ("name", "message"),
        {"name": "test_user", "message": "hello world"},
    )
    # Verify form() was called
    mock_request.form.assert_called_once()
    # body() must not be consulted for form-encoded payloads.
    # NOTE: the previous check `not hasattr(mock_request, "body") or not
    # mock_request.body.called` was vacuous — hasattr() is always True on a
    # MagicMock because attribute access auto-creates the attribute, so only
    # the second disjunct ever mattered. assert_not_called() states the
    # intent directly and fails with a useful message.
    mock_request.body.assert_not_called()
@pytest.mark.asyncio
async def test_empty_request_body():
    """An empty JSON body parses to {} and is still cached in request.scope."""
    request = MagicMock()
    # Empty bytes as the awaited body() result
    request.body = AsyncMock(return_value=b"")
    request.headers = {"content-type": "application/json"}
    request.scope = {}

    parsed = await _read_request_body(request)

    # Empty payload -> empty dict, with an (empty) cache entry recorded
    assert parsed == {}
    assert "parsed_body" in request.scope
    assert request.scope["parsed_body"] == ((), {})
    request.body.assert_called_once()
@pytest.mark.asyncio
async def test_circular_reference_handling():
    """
    Mutating the dict returned by _read_request_body must not pollute the
    per-request cache. Guards against the mutable-reference bug where a
    circular "proxy_server_request" entry leaked into the cached body.
    """
    request = MagicMock()
    original_payload = {
        "model": "gpt-4",
        "messages": [{"role": "user", "content": "Hello"}],
    }
    request.body = AsyncMock(return_value=orjson.dumps(original_payload))
    request.headers = {"content-type": "application/json"}
    request.scope = {}

    # First parse returns the expected payload
    first = await _read_request_body(request)
    assert first["model"] == "gpt-4"
    assert first["messages"] == [{"role": "user", "content": "Hello"}]

    # Mutate the returned dict, deliberately introducing a circular reference
    first["proxy_server_request"] = {
        "url": "http://0.0.0.0:4000/v1/chat/completions",
        "method": "POST",
        "headers": {"content-type": "application/json"},
        "body": first,  # Creates circular reference
    }

    # Re-reading the same request must yield a clean copy from the cache
    second = await _read_request_body(request)
    assert "proxy_server_request" not in second
@pytest.mark.asyncio
async def test_json_parsing_error_handling():
    """
    Test that JSON parsing errors are properly handled and raise ProxyException
    with appropriate error messages.
    """
    # Test case 1: Trailing comma error
    mock_request = MagicMock()
    # The comma after the "x-litellm-api-key" entry makes this invalid JSON
    invalid_json_with_trailing_comma = b'''{
        "model": "gpt-4o",
        "tools": [
            {
                "type": "mcp",
                "server_label": "litellm",
                "headers": {
                    "x-litellm-api-key": "Bearer sk-1234",
                }
            }
        ],
        "input": "Run available tools"
    }'''
    mock_request.body = AsyncMock(return_value=invalid_json_with_trailing_comma)
    mock_request.headers = {"content-type": "application/json"}
    mock_request.scope = {}
    # Should raise ProxyException for trailing comma
    with pytest.raises(ProxyException) as exc_info:
        await _read_request_body(mock_request)
    # ProxyException carries a string HTTP code and a human-readable message
    assert exc_info.value.code == "400"
    assert "Invalid JSON payload" in exc_info.value.message
    assert "trailing comma" in exc_info.value.message
    # Test case 2: Unquoted property name error
    mock_request2 = MagicMock()
    # `type` is missing quotes, which is invalid JSON
    invalid_json_unquoted_property = b'''{
        "model": "gpt-4o",
        "tools": [
            {
                type: "mcp",
                "server_label": "litellm"
            }
        ],
        "input": "Run available tools"
    }'''
    mock_request2.body = AsyncMock(return_value=invalid_json_unquoted_property)
    mock_request2.headers = {"content-type": "application/json"}
    mock_request2.scope = {}
    # Should raise ProxyException for unquoted property
    with pytest.raises(ProxyException) as exc_info2:
        await _read_request_body(mock_request2)
    assert exc_info2.value.code == "400"
    assert "Invalid JSON payload" in exc_info2.value.message
    # Test case 3: Valid JSON should work normally
    mock_request3 = MagicMock()
    valid_json = b'''{
        "model": "gpt-4o",
        "tools": [
            {
                "type": "mcp",
                "server_label": "litellm",
                "headers": {
                    "x-litellm-api-key": "Bearer sk-1234"
                }
            }
        ],
        "input": "Run available tools"
    }'''
    mock_request3.body = AsyncMock(return_value=valid_json)
    mock_request3.headers = {"content-type": "application/json"}
    mock_request3.scope = {}
    # Should parse successfully
    result = await _read_request_body(mock_request3)
    assert result["model"] == "gpt-4o"
    assert result["input"] == "Run available tools"
    assert len(result["tools"]) == 1
    assert result["tools"][0]["type"] == "mcp"
@pytest.mark.asyncio
async def test_get_form_data():
    """
    Test that get_form_data correctly handles form data with array notation.
    Tests audio transcription parameters as a specific example.
    """
    # Create a mock request with transcription form data
    mock_request = MagicMock()
    # NOTE: a real request exposes a MultiDict that can carry repeated
    # "timestamp_granularities[]" entries; a plain dict literal cannot.
    # The original version listed that key twice, which is a lint error
    # (pylint W0109) — the second literal silently overwrote the first, so
    # only one entry ever reached the code under test. Keep a single entry
    # and document the MultiDict limitation here instead.
    mock_form_data = {
        "file": "file_object",  # In a real request this would be an UploadFile
        "model": "gpt-4o-transcribe",
        "include[]": "logprobs",  # Array notation
        "language": "en",
        "prompt": "Transcribe this audio file",
        "response_format": "json",
        "stream": "false",
        "temperature": "0.2",
        "timestamp_granularities[]": "segment",  # Array notation, single item
    }
    # Mock the form method to return the test data
    mock_request.form = AsyncMock(return_value=mock_form_data)
    # Call the function being tested
    result = await get_form_data(mock_request)
    # Verify regular form fields are preserved
    assert result["file"] == "file_object"
    assert result["model"] == "gpt-4o-transcribe"
    assert result["language"] == "en"
    assert result["prompt"] == "Transcribe this audio file"
    assert result["response_format"] == "json"
    assert result["stream"] == "false"
    assert result["temperature"] == "0.2"
    # Verify "key[]" fields are stripped of the suffix and collected as lists
    assert "include" in result
    assert isinstance(result["include"], list)
    assert "logprobs" in result["include"]
    assert "timestamp_granularities" in result
    assert isinstance(result["timestamp_granularities"], list)
    assert "segment" in result["timestamp_granularities"]

View File

@@ -0,0 +1,90 @@
from unittest.mock import MagicMock, mock_open, patch
import pytest
import yaml
from litellm.proxy.common_utils.load_config_utils import get_file_contents_from_s3
class TestGetFileContentsFromS3:
    """Test suite for S3 config loading functionality."""

    # Decorators are applied bottom-up, so the mock parameters arrive in the
    # order: yaml.safe_load, bedrock_converse_chat_completion, boto3.client.
    @patch('boto3.client')
    @patch('litellm.main.bedrock_converse_chat_completion')
    @patch('yaml.safe_load')
    def test_get_file_contents_from_s3_no_temp_file_creation(
        self, mock_yaml_load, mock_bedrock, mock_boto3_client
    ):
        """
        Test that get_file_contents_from_s3 doesn't create temporary files
        and uses yaml.safe_load directly on the S3 response content.
        Note: It's critical that yaml.safe_load is used
        Relevant issue/PR: https://github.com/BerriAI/litellm/pull/12078
        """
        # Mock credentials
        mock_credentials = MagicMock()
        mock_credentials.access_key = "test_access_key"
        mock_credentials.secret_key = "test_secret_key"
        mock_credentials.token = "test_token"
        mock_bedrock.get_credentials.return_value = mock_credentials
        # Mock S3 client and response
        mock_s3_client = MagicMock()
        mock_boto3_client.return_value = mock_s3_client
        # Mock S3 response with YAML content
        yaml_content = """
        model_list:
          - model_name: gpt-3.5-turbo
            litellm_params:
              model: gpt-3.5-turbo
        """
        mock_response_body = MagicMock()
        mock_response_body.read.return_value = yaml_content.encode('utf-8')
        mock_s3_response = {
            'Body': mock_response_body
        }
        mock_s3_client.get_object.return_value = mock_s3_response
        # Mock yaml.safe_load to return parsed config
        expected_config = {
            'model_list': [{
                'model_name': 'gpt-3.5-turbo',
                'litellm_params': {
                    'model': 'gpt-3.5-turbo'
                }
            }]
        }
        mock_yaml_load.return_value = expected_config
        # Call the function
        bucket_name = "test-bucket"
        object_key = "config.yaml"
        result = get_file_contents_from_s3(bucket_name, object_key)
        # Assertions
        assert result == expected_config
        # Verify S3 client was created with correct credentials
        mock_boto3_client.assert_called_once_with(
            "s3",
            aws_access_key_id="test_access_key",
            aws_secret_access_key="test_secret_key",
            aws_session_token="test_token"
        )
        # Verify S3 get_object was called with correct parameters
        mock_s3_client.get_object.assert_called_once_with(
            Bucket=bucket_name,
            Key=object_key
        )
        # Verify the response body was read and decoded
        mock_response_body.read.assert_called_once()
        # Verify yaml.safe_load was called with the decoded content
        mock_yaml_load.assert_called_once_with(yaml_content)

View File

@@ -0,0 +1,87 @@
import pytest
from litellm.proxy.common_utils.openai_endpoint_utils import remove_sensitive_info_from_deployment
@pytest.mark.parametrize(
    "model_config, expected_config",
    [
        # Test case 1: Empty litellm_params
        (
            {
                "model_name": "test-model",
                "litellm_params": {}
            },
            {
                "model_name": "test-model",
                "litellm_params": {}
            }
        ),
        # Test case 2: Full sensitive data removal, mixed secrets of azure, aws, gcp, and typical api_key
        (
            {
                "model_name": "gpt-4",
                "litellm_params": {
                    "model": "openai/gpt-4",
                    "api_key": "sk-sensitive-key-123",
                    "client_secret": "~v8Q4W:Zp9gJ-3sTqX5aB@LkR2mNfYdC",
                    "vertex_credentials": {"type": "service_account"},
                    "aws_access_key_id": "AKIA123456789",
                    "aws_secret_access_key": "secret-access-key",
                    "api_base": "https://api.openai.com/v1",
                    "temperature": 0.7
                },
                "model_info": {"id": "test-id"}
            },
            {
                "model_name": "gpt-4",
                "litellm_params": {
                    "model": "openai/gpt-4",
                    "api_base": "https://api.openai.com/v1",
                    "temperature": 0.7
                },
                "model_info": {"id": "test-id"}
            }
        ),
        # Test case 3: Partial sensitive data, api_key
        (
            {
                "model_name": "claude-3",
                "litellm_params": {
                    "model": "anthropic/claude-3",
                    "api_key": "sk-anthropic-key",
                    "temperature": 0.5
                }
            },
            {
                "model_name": "claude-3",
                "litellm_params": {
                    "model": "anthropic/claude-3",
                    "temperature": 0.5
                }
            }
        ),
        # Test case 4: No sensitive data
        (
            {
                "model_name": "local-model",
                "litellm_params": {
                    "model": "local/model",
                    "temperature": 0.8,
                    "max_tokens": 100
                }
            },
            {
                "model_name": "local-model",
                "litellm_params": {
                    "model": "local/model",
                    "temperature": 0.8,
                    "max_tokens": 100
                }
            }
        )
    ]
)
def test_remove_sensitive_info_from_deployment(model_config: dict, expected_config: dict):
    """Sanitizing a deployment config strips credential fields from
    litellm_params while leaving non-sensitive keys and model_info intact."""
    sanitized_config = remove_sensitive_info_from_deployment(model_config)
    assert sanitized_config == expected_config

View File

@@ -0,0 +1,322 @@
import asyncio
import os
import sys
import time
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, List
import pytest
sys.path.insert(
0, os.path.abspath("../../..")
) # Adds the parent directory to the system path
from litellm._logging import verbose_proxy_logger
from litellm.proxy.common_utils.reset_budget_job import ResetBudgetJob
from litellm.proxy.utils import ProxyLogging
# Mock classes for testing
class MockLiteLLMTeamMembership:
    """Stand-in for the prisma `litellm_teammembership` table client."""

    async def update_many(
        self, where: Dict[str, Any], data: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Pretend exactly one membership row was updated."""
        return {"count": 1}
class MockDB:
    """Container mimicking `prisma_client.db` with the one table the job touches."""

    def __init__(self):
        # ResetBudgetJob only exercises the team-membership table.
        self.litellm_teammembership = MockLiteLLMTeamMembership()
class MockPrismaClient:
    """In-memory stand-in for the Prisma client used by ResetBudgetJob.

    `data` holds the rows each table "contains"; `update_data` records what
    would have been written into `updated_data` so tests can assert on it.
    """

    def __init__(self):
        table_names = ("key", "user", "team", "budget", "enduser")
        self.data: Dict[str, List[Any]] = {name: [] for name in table_names}
        self.updated_data: Dict[str, List[Any]] = {name: [] for name in table_names}
        self.db = MockDB()

    async def get_data(self, table_name, query_type, **kwargs):
        rows = self.data.get(table_name, [])
        # Budget table: rows carrying a reset timestamp count as "expired".
        if table_name == "budget" and query_type == "find_all" and "reset_at" in kwargs:
            return [row for row in rows if hasattr(row, "budget_reset_at")]
        # End-user table: match rows whose linked budget is in the requested id list.
        if (
            table_name == "enduser"
            and query_type == "find_all"
            and "budget_id_list" in kwargs
        ):
            wanted_ids = kwargs["budget_id_list"]
            return [
                row
                for row in rows
                if hasattr(row, "litellm_budget_table")
                and hasattr(row.litellm_budget_table, "budget_id")
                and row.litellm_budget_table.budget_id in wanted_ids
            ]
        # Key table: rows with a reset timestamp are due for reset.
        if (
            table_name == "key"
            and query_type == "find_all"
            and ("expires" in kwargs or "reset_at" in kwargs)
        ):
            return [row for row in rows if hasattr(row, "budget_reset_at")]
        # No special filter applies — hand back everything stored.
        return rows

    async def update_data(self, query_type, data_list, table_name):
        # Record the write so tests can inspect what the job produced.
        self.updated_data[table_name] = data_list
        return data_list
class MockProxyLogging:
    """Minimal ProxyLogging replacement exposing a no-op service logger."""

    class MockServiceLogging:
        """Accepts the success/failure service hooks and discards them."""

        async def async_service_success_hook(self, **kwargs):
            return None

        async def async_service_failure_hook(self, **kwargs):
            return None

    def __init__(self):
        self.service_logging_obj = self.MockServiceLogging()
# Test fixtures
@pytest.fixture
def mock_prisma_client():
    """Fresh in-memory Prisma stand-in for each test."""
    return MockPrismaClient()


@pytest.fixture
def mock_proxy_logging():
    """No-op proxy-logging stand-in."""
    return MockProxyLogging()


@pytest.fixture
def reset_budget_job(mock_prisma_client, mock_proxy_logging):
    """ResetBudgetJob wired to the in-memory mocks above."""
    return ResetBudgetJob(
        proxy_logging_obj=mock_proxy_logging, prisma_client=mock_prisma_client
    )


# Helper function to run async tests
async def run_async_test(coro):
    """Await *coro*; thin indirection kept for readability in tests."""
    return await coro
# Tests
def test_reset_budget_for_key(reset_budget_job, mock_prisma_client):
    """Keys past their reset time get spend zeroed and a future reset date."""
    now = datetime.now(timezone.utc)
    # type() builds a throwaway class whose class attributes emulate a DB row.
    expired_key = type(
        "LiteLLM_VerificationToken",
        (),
        {
            "spend": 100.0,
            "budget_duration": "30d",
            "budget_reset_at": now,
            "id": "test-key-1",
        },
    )
    mock_prisma_client.data["key"] = [expired_key]

    asyncio.run(reset_budget_job.reset_budget_for_litellm_keys())

    updated = mock_prisma_client.updated_data["key"]
    assert len(updated) == 1
    assert updated[0].spend == 0.0
    assert updated[0].budget_reset_at > now
def test_reset_budget_for_user(reset_budget_job, mock_prisma_client):
    """User spend is zeroed and the reset timestamp pushed into the future."""
    now = datetime.now(timezone.utc)
    # Throwaway class standing in for a user row
    expired_user = type(
        "LiteLLM_UserTable",
        (),
        {
            "spend": 200.0,
            "budget_duration": "7d",
            "budget_reset_at": now,
            "id": "test-user-1",
        },
    )
    mock_prisma_client.data["user"] = [expired_user]

    asyncio.run(reset_budget_job.reset_budget_for_litellm_users())

    updated = mock_prisma_client.updated_data["user"]
    assert len(updated) == 1
    assert updated[0].spend == 0.0
    assert updated[0].budget_reset_at > now
def test_reset_budget_for_team(reset_budget_job, mock_prisma_client):
    """Team spend is zeroed and the reset timestamp pushed into the future."""
    now = datetime.now(timezone.utc)
    # Throwaway class standing in for a team row
    expired_team = type(
        "LiteLLM_TeamTable",
        (),
        {
            "spend": 500.0,
            "budget_duration": "1mo",
            "budget_reset_at": now,
            "id": "test-team-1",
        },
    )
    mock_prisma_client.data["team"] = [expired_team]

    asyncio.run(reset_budget_job.reset_budget_for_litellm_teams())

    updated = mock_prisma_client.updated_data["team"]
    assert len(updated) == 1
    assert updated[0].spend == 0.0
    assert updated[0].budget_reset_at > now
def test_reset_budget_for_enduser(reset_budget_job, mock_prisma_client):
    """End-user spend resets and the linked budget row gets a new reset date."""
    now = datetime.now(timezone.utc)
    # Budget row linked to the end-user below
    budget_row = type(
        "LiteLLM_BudgetTable",
        (),
        {
            "max_budget": 500.0,
            "budget_duration": "1d",
            "budget_reset_at": now,
            "budget_id": "test-budget-1",
        },
    )
    # End-user row pointing at that budget
    enduser_row = type(
        "LiteLLM_EndUserTable",
        (),
        {
            "spend": 20.0,
            "litellm_budget_table": budget_row,
            "user_id": "test-enduser-1",
        },
    )
    mock_prisma_client.data["budget"] = [budget_row]
    mock_prisma_client.data["enduser"] = [enduser_row]

    asyncio.run(reset_budget_job.reset_budget_for_litellm_budget_table())

    # Both the enduser spend and the budget reset date must be updated
    assert len(mock_prisma_client.updated_data["enduser"]) == 1
    assert len(mock_prisma_client.updated_data["budget"]) == 1
    assert mock_prisma_client.updated_data["enduser"][0].spend == 0.0
    assert mock_prisma_client.updated_data["budget"][0].budget_reset_at > now
def test_reset_budget_all(reset_budget_job, mock_prisma_client):
    """End-to-end: reset_budget() resets keys, users, teams, budgets and endusers."""
    # Setup test data with timezone-aware datetime
    now = datetime.now(timezone.utc)
    # Create test objects for all three types
    # NOTE: type(...) returns a class object; the job reads these class
    # attributes as if they were row fields.
    test_key = type(
        "LiteLLM_VerificationToken",
        (),
        {
            "spend": 100.0,
            "budget_duration": "30d",
            "budget_reset_at": now,
            "id": "test-key-1",
        },
    )
    test_user = type(
        "LiteLLM_UserTable",
        (),
        {
            "spend": 200.0,
            "budget_duration": "7d",
            "budget_reset_at": now,
            "id": "test-user-1",
        },
    )
    test_team = type(
        "LiteLLM_TeamTable",
        (),
        {
            "spend": 500.0,
            "budget_duration": "1mo",
            "budget_reset_at": now,
            "id": "test-team-1",
        },
    )
    test_budget = type(
        "LiteLLM_BudgetTable",
        (),
        {
            "max_budget": 500.0,
            "budget_duration": "1d",
            "budget_reset_at": now,
            "budget_id": "test-budget-1",
        },
    )
    test_enduser = type(
        "LiteLLM_EndUserTable",
        (),
        {
            "spend": 20.0,
            "litellm_budget_table": test_budget,
            "user_id": "test-enduser-1",
        },
    )
    mock_prisma_client.data["key"] = [test_key]
    mock_prisma_client.data["user"] = [test_user]
    mock_prisma_client.data["team"] = [test_team]
    mock_prisma_client.data["budget"] = [test_budget]
    mock_prisma_client.data["enduser"] = [test_enduser]
    # Run the test
    asyncio.run(reset_budget_job.reset_budget())
    # Verify results: every table saw exactly one write
    assert len(mock_prisma_client.updated_data["key"]) == 1
    assert len(mock_prisma_client.updated_data["user"]) == 1
    assert len(mock_prisma_client.updated_data["team"]) == 1
    assert len(mock_prisma_client.updated_data["enduser"]) == 1
    assert len(mock_prisma_client.updated_data["budget"]) == 1
    # Check that all spends were reset to 0
    assert mock_prisma_client.updated_data["key"][0].spend == 0.0
    assert mock_prisma_client.updated_data["user"][0].spend == 0.0
    assert mock_prisma_client.updated_data["team"][0].spend == 0.0
    assert mock_prisma_client.updated_data["enduser"][0].spend == 0.0

View File

@@ -0,0 +1,35 @@
import asyncio
import json
import os
import sys
import time
from datetime import datetime, timedelta, timezone
import pytest
from fastapi.testclient import TestClient
sys.path.insert(
0, os.path.abspath("../../..")
) # Adds the parent directory to the system path
from litellm.proxy.common_utils.timezone_utils import get_budget_reset_time
def test_get_budget_reset_time():
    """
    A "1mo" budget duration should reset at midnight UTC on the first day
    of the following month.
    """
    now = datetime.now(timezone.utc)
    # Roll over into January of next year when the current month is December.
    if now.month == 12:
        first_of_next_month = datetime(now.year + 1, 1, 1, tzinfo=timezone.utc)
    else:
        first_of_next_month = datetime(now.year, now.month + 1, 1, tzinfo=timezone.utc)
    assert get_budget_reset_time(budget_duration="1mo") == first_of_next_month

View File

@@ -0,0 +1,318 @@
# tests/litellm/proxy/common_utils/test_upsert_budget_membership.py
import types
from unittest.mock import AsyncMock, MagicMock
import pytest
from litellm.proxy.management_endpoints.common_utils import (
_upsert_budget_and_membership,
)
# ---------------------------------------------------------------------------
# Fixtures: a fake Prisma transaction and a fake UserAPIKeyAuth object
# ---------------------------------------------------------------------------
@pytest.fixture
def mock_tx():
    """
    Build an object that quacks like the Prisma transaction used inside
    _upsert_budget_and_membership: a team-membership table and a budget table,
    each with awaitable write methods.
    """
    membership_table = MagicMock()
    membership_table.update = AsyncMock()
    membership_table.upsert = AsyncMock()

    budget_table = MagicMock()
    budget_table.update = AsyncMock()
    # create() yields a fake row exposing the generated budget_id
    budget_table.create = AsyncMock(
        return_value=types.SimpleNamespace(budget_id="new-budget-123")
    )

    tx = MagicMock()
    tx.litellm_teammembership = membership_table
    tx.litellm_budgettable = budget_table
    return tx
@pytest.fixture
def fake_user():
    """Minimal object exposing the one attribute the code under test reads: user_id."""
    user = types.SimpleNamespace()
    user.user_id = "tester@example.com"
    return user
# TEST: max_budget is None, disconnect only
@pytest.mark.asyncio
async def test_upsert_disconnect(mock_tx, fake_user):
    """With max_budget=None the membership is detached from its budget and
    no budget rows are created or touched."""
    await _upsert_budget_and_membership(
        mock_tx,
        team_id="team-1",
        user_id="user-1",
        max_budget=None,
        existing_budget_id=None,
        user_api_key_dict=fake_user,
    )
    expected_where = {"user_id_team_id": {"user_id": "user-1", "team_id": "team-1"}}
    expected_data = {"litellm_budget_table": {"disconnect": True}}
    mock_tx.litellm_teammembership.update.assert_awaited_once_with(
        where=expected_where, data=expected_data
    )
    # Nothing else on the transaction should have been exercised.
    mock_tx.litellm_budgettable.update.assert_not_called()
    mock_tx.litellm_budgettable.create.assert_not_called()
    mock_tx.litellm_teammembership.upsert.assert_not_called()
# TEST: existing budget id, creates new budget (current behavior)
@pytest.mark.asyncio
async def test_upsert_with_existing_budget_id_creates_new(mock_tx, fake_user):
    """
    Even when existing_budget_id is supplied, the current implementation
    ignores it and creates a brand-new budget row.
    """
    await _upsert_budget_and_membership(
        mock_tx,
        team_id="team-2",
        user_id="user-2",
        max_budget=42.0,
        existing_budget_id="bud-999",  # currently unused by the implementation
        user_api_key_dict=fake_user,
    )
    # A fresh budget row is created, not an in-place update of "bud-999".
    mock_tx.litellm_budgettable.create.assert_awaited_once_with(
        data={
            "max_budget": 42.0,
            "created_by": fake_user.user_id,
            "updated_by": fake_user.user_id,
        },
        include={"team_membership": True},
    )
    # Membership is (re)pointed at whatever budget_id create() produced.
    budget_id = mock_tx.litellm_budgettable.create.return_value.budget_id
    link = {"litellm_budget_table": {"connect": {"budget_id": budget_id}}}
    mock_tx.litellm_teammembership.upsert.assert_awaited_once_with(
        where={"user_id_team_id": {"user_id": "user-2", "team_id": "team-2"}},
        data={
            "create": {"user_id": "user-2", "team_id": "team-2", **link},
            "update": dict(link),
        },
    )
    # Neither pre-existing row is mutated in place.
    mock_tx.litellm_budgettable.update.assert_not_called()
    mock_tx.litellm_teammembership.update.assert_not_called()
# TEST: create new budget and link membership
@pytest.mark.asyncio
async def test_upsert_create_and_link(mock_tx, fake_user):
    """No existing budget: a fresh budget row is created and the team
    membership is upserted to point at it."""
    await _upsert_budget_and_membership(
        mock_tx,
        team_id="team-3",
        user_id="user-3",
        max_budget=99.9,
        existing_budget_id=None,
        user_api_key_dict=fake_user,
    )
    mock_tx.litellm_budgettable.create.assert_awaited_once_with(
        data={
            "max_budget": 99.9,
            "created_by": fake_user.user_id,
            "updated_by": fake_user.user_id,
        },
        include={"team_membership": True},
    )
    # The membership must be linked to whatever budget_id create() handed back.
    budget_id = mock_tx.litellm_budgettable.create.return_value.budget_id
    link = {"litellm_budget_table": {"connect": {"budget_id": budget_id}}}
    mock_tx.litellm_teammembership.upsert.assert_awaited_once_with(
        where={"user_id_team_id": {"user_id": "user-3", "team_id": "team-3"}},
        data={
            "create": {"user_id": "user-3", "team_id": "team-3", **link},
            "update": dict(link),
        },
    )
    # No in-place updates should have happened on either table.
    mock_tx.litellm_teammembership.update.assert_not_called()
    mock_tx.litellm_budgettable.update.assert_not_called()
# TEST: create new budget and link membership, then create another new budget
@pytest.mark.asyncio
async def test_upsert_create_then_create_another(mock_tx, fake_user):
    """
    Two successive calls each create their own budget row — the second call
    ignores existing_budget_id (current implementation behavior).
    """
    # --- First call: create a budget and link the membership ---
    await _upsert_budget_and_membership(
        mock_tx,
        team_id="team-42",
        user_id="user-42",
        max_budget=10.0,
        existing_budget_id=None,
        user_api_key_dict=fake_user,
    )
    first_budget_id = mock_tx.litellm_budgettable.create.return_value.budget_id
    # Sanity: the create + upsert path really ran.
    mock_tx.litellm_budgettable.create.assert_awaited_once()
    mock_tx.litellm_teammembership.upsert.assert_awaited_once()

    # --- Second call: wipe call history, then change the same member's limit ---
    for mocked in (
        mock_tx.litellm_budgettable.create,
        mock_tx.litellm_teammembership.upsert,
        mock_tx.litellm_budgettable.update,
    ):
        mocked.reset_mock()
    mock_tx.litellm_budgettable.create.return_value = types.SimpleNamespace(
        budget_id="new-budget-456"
    )
    await _upsert_budget_and_membership(
        mock_tx,
        team_id="team-42",
        user_id="user-42",
        max_budget=25.0,  # new limit
        existing_budget_id=first_budget_id,  # ignored by the current implementation
        user_api_key_dict=fake_user,
    )
    # Another brand-new budget is created (no update of the first one).
    mock_tx.litellm_budgettable.create.assert_awaited_once_with(
        data={
            "max_budget": 25.0,
            "created_by": fake_user.user_id,
            "updated_by": fake_user.user_id,
        },
        include={"team_membership": True},
    )
    second_budget_id = mock_tx.litellm_budgettable.create.return_value.budget_id
    link = {"litellm_budget_table": {"connect": {"budget_id": second_budget_id}}}
    mock_tx.litellm_teammembership.upsert.assert_awaited_once_with(
        where={"user_id_team_id": {"user_id": "user-42", "team_id": "team-42"}},
        data={
            "create": {"user_id": "user-42", "team_id": "team-42", **link},
            "update": dict(link),
        },
    )
    mock_tx.litellm_budgettable.update.assert_not_called()
# TEST: update rpm_limit for member with existing budget_id
@pytest.mark.asyncio
async def test_upsert_rpm_limit_update_creates_new_budget(mock_tx, fake_user):
    """
    Updating rpm_limit for a member that already has a budget_id still creates
    a brand-new budget (carrying all limits) and re-points the membership at it.
    """
    await _upsert_budget_and_membership(
        mock_tx,
        team_id="team-rpm-test",
        user_id="user-rpm-test",
        max_budget=50.0,
        existing_budget_id="existing-budget-456",
        user_api_key_dict=fake_user,
        tpm_limit=1000,
        rpm_limit=100,  # the limit being updated
    )
    # A new budget row carries every specified limit.
    mock_tx.litellm_budgettable.create.assert_awaited_once_with(
        data={
            "max_budget": 50.0,
            "tpm_limit": 1000,
            "rpm_limit": 100,
            "created_by": fake_user.user_id,
            "updated_by": fake_user.user_id,
        },
        include={"team_membership": True},
    )
    # The pre-existing budget row is left untouched.
    mock_tx.litellm_budgettable.update.assert_not_called()
    # Membership is re-pointed at the freshly created budget.
    budget_id = mock_tx.litellm_budgettable.create.return_value.budget_id
    link = {"litellm_budget_table": {"connect": {"budget_id": budget_id}}}
    mock_tx.litellm_teammembership.upsert.assert_awaited_once_with(
        where={
            "user_id_team_id": {
                "user_id": "user-rpm-test",
                "team_id": "team-rpm-test",
            }
        },
        data={
            "create": {"user_id": "user-rpm-test", "team_id": "team-rpm-test", **link},
            "update": dict(link),
        },
    )
# TEST: create new budget with only rpm_limit (no max_budget)
@pytest.mark.asyncio
async def test_upsert_rpm_only_creates_new_budget(mock_tx, fake_user):
    """A call supplying only rpm_limit creates a budget containing just that limit."""
    await _upsert_budget_and_membership(
        mock_tx,
        team_id="team-rpm-only",
        user_id="user-rpm-only",
        max_budget=None,
        existing_budget_id=None,
        user_api_key_dict=fake_user,
        rpm_limit=50,
    )
    # Only the rpm_limit (plus audit fields) lands in the new budget row.
    mock_tx.litellm_budgettable.create.assert_awaited_once_with(
        data={
            "rpm_limit": 50,
            "created_by": fake_user.user_id,
            "updated_by": fake_user.user_id,
        },
        include={"team_membership": True},
    )
    # Membership is linked to the freshly created budget.
    budget_id = mock_tx.litellm_budgettable.create.return_value.budget_id
    link = {"litellm_budget_table": {"connect": {"budget_id": budget_id}}}
    mock_tx.litellm_teammembership.upsert.assert_awaited_once_with(
        where={
            "user_id_team_id": {
                "user_id": "user-rpm-only",
                "team_id": "team-rpm-only",
            }
        },
        data={
            "create": {"user_id": "user-rpm-only", "team_id": "team-rpm-only", **link},
            "update": dict(link),
        },
    )