Added LiteLLM to the stack
This commit is contained in:
@@ -0,0 +1,91 @@
|
||||
"""
|
||||
Simple unit tests for CustomOpenAPISpec class.
|
||||
|
||||
Tests basic functionality of OpenAPI schema generation.
|
||||
"""
|
||||
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from litellm.proxy.common_utils.custom_openapi_spec import CustomOpenAPISpec
|
||||
|
||||
|
||||
class TestCustomOpenAPISpec:
    """Test suite for CustomOpenAPISpec class.

    All three request-schema hooks share one shape: patch the request model,
    invoke the hook, and verify it delegated to ``add_request_schema`` with
    the right arguments. That shared flow lives in ``_assert_schema_added``
    so each test only states the hook-specific values.
    """

    @pytest.fixture
    def base_openapi_schema(self):
        """Base OpenAPI schema for testing."""
        return {
            "openapi": "3.0.0",
            "info": {"title": "Test API", "version": "1.0.0"},
            "paths": {
                "/v1/chat/completions": {
                    "post": {
                        "summary": "Chat completions"
                    }
                },
                "/v1/embeddings": {
                    "post": {
                        "summary": "Embeddings"
                    }
                },
                "/v1/responses": {
                    "post": {
                        "summary": "Responses API"
                    }
                }
            }
        }

    def _assert_schema_added(
        self,
        mock_add_schema,
        base_openapi_schema,
        hook,
        model_patch_target,
        schema_name,
        paths,
        operation_name,
    ):
        """Drive one schema hook and assert it delegated to add_request_schema.

        Args:
            mock_add_schema: the patched ``add_request_schema`` classmethod.
            base_openapi_schema: schema dict passed through the hook.
            hook: the ``CustomOpenAPISpec`` classmethod under test.
            model_patch_target: dotted path of the request model to patch.
            schema_name / paths / operation_name: expected delegation args.
        """
        mock_add_schema.return_value = base_openapi_schema

        with patch(model_patch_target) as mock_model:
            result = hook(base_openapi_schema)

            mock_add_schema.assert_called_once_with(
                openapi_schema=base_openapi_schema,
                model_class=mock_model,
                schema_name=schema_name,
                paths=paths,
                operation_name=operation_name,
            )
        assert result == base_openapi_schema

    @patch('litellm.proxy.common_utils.custom_openapi_spec.CustomOpenAPISpec.add_request_schema')
    def test_add_chat_completion_request_schema(self, mock_add_schema, base_openapi_schema):
        """Test that chat completion schema is added correctly."""
        self._assert_schema_added(
            mock_add_schema,
            base_openapi_schema,
            CustomOpenAPISpec.add_chat_completion_request_schema,
            'litellm.proxy._types.ProxyChatCompletionRequest',
            "ProxyChatCompletionRequest",
            CustomOpenAPISpec.CHAT_COMPLETION_PATHS,
            "chat completion",
        )

    @patch('litellm.proxy.common_utils.custom_openapi_spec.CustomOpenAPISpec.add_request_schema')
    def test_add_embedding_request_schema(self, mock_add_schema, base_openapi_schema):
        """Test that embedding schema is added correctly."""
        self._assert_schema_added(
            mock_add_schema,
            base_openapi_schema,
            CustomOpenAPISpec.add_embedding_request_schema,
            'litellm.types.embedding.EmbeddingRequest',
            "EmbeddingRequest",
            CustomOpenAPISpec.EMBEDDING_PATHS,
            "embedding",
        )

    @patch('litellm.proxy.common_utils.custom_openapi_spec.CustomOpenAPISpec.add_request_schema')
    def test_add_responses_api_request_schema(self, mock_add_schema, base_openapi_schema):
        """Test that responses API schema is added correctly."""
        self._assert_schema_added(
            mock_add_schema,
            base_openapi_schema,
            CustomOpenAPISpec.add_responses_api_request_schema,
            'litellm.types.llms.openai.ResponsesAPIRequestParams',
            "ResponsesAPIRequestParams",
            CustomOpenAPISpec.RESPONSES_API_PATHS,
            "responses API",
        )
|
@@ -0,0 +1,222 @@
|
||||
"""
|
||||
Unit tests for GetRoutes utility class.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List
|
||||
from unittest.mock import MagicMock, Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from litellm.proxy.common_utils.get_routes import GetRoutes
|
||||
|
||||
|
||||
class TestGetRoutes:
    """Test suite for the GetRoutes helper that enumerates FastAPI app routes,
    including routes of mounted sub-apps (e.g. the /mcp and /ui mounts)."""

    def test_get_app_routes_regular_route(self):
        """Test getting routes for a regular route with endpoint."""
        # Mock a regular route
        mock_route = Mock()
        mock_route.path = "/test/endpoint"
        mock_route.methods = ["GET", "POST"]
        mock_route.name = "test_endpoint"
        mock_route.endpoint = Mock()

        # Mock endpoint function
        mock_endpoint = Mock()
        mock_endpoint.__name__ = "test_function"

        result = GetRoutes.get_app_routes(mock_route, mock_endpoint)

        # One route record is produced, with the endpoint's __name__ captured.
        assert len(result) == 1
        assert result[0]["path"] == "/test/endpoint"
        assert result[0]["methods"] == ["GET", "POST"]
        assert result[0]["name"] == "test_endpoint"
        assert result[0]["endpoint"] == "test_function"

    def test_get_routes_for_mounted_app_regular_routes(self):
        """Test getting routes for mounted app with regular API routes."""
        # Mock the main mount route
        mock_mount_route = Mock()
        mock_mount_route.path = "/mcp"

        # Mock sub-app with regular routes
        mock_sub_app = Mock()
        mock_sub_app.routes = []

        # Create a regular API route
        mock_api_route = Mock()
        mock_api_route.path = "/enabled"
        mock_api_route.methods = ["GET"]
        mock_api_route.name = "get_mcp_server_enabled"

        # Mock endpoint function
        mock_endpoint = Mock()
        mock_endpoint.__name__ = "get_mcp_server_enabled"
        mock_api_route.endpoint = mock_endpoint
        mock_api_route.app = None  # Regular route doesn't have app

        mock_sub_app.routes.append(mock_api_route)
        mock_mount_route.app = mock_sub_app

        result = GetRoutes.get_routes_for_mounted_app(mock_mount_route)

        # The mount prefix "/mcp" is joined with the sub-route path "/enabled".
        assert len(result) == 1
        assert result[0]["path"] == "/mcp/enabled"
        assert result[0]["methods"] == ["GET"]
        assert result[0]["name"] == "get_mcp_server_enabled"
        assert result[0]["endpoint"] == "get_mcp_server_enabled"
        assert result[0]["mounted_app"] is True

    def test_get_routes_for_mounted_app_mount_objects(self):
        """Test getting routes for mounted app with Mount objects (the main fix)."""
        # Mock the main mount route
        mock_mount_route = Mock()
        mock_mount_route.path = "/mcp"

        # Mock sub-app
        mock_sub_app = Mock()
        mock_sub_app.routes = []

        # Create Mount object for base MCP route (path='')
        # spec=[...] restricts attribute access so hasattr checks in GetRoutes
        # behave like a real Mount (no auto-created attributes).
        mock_mount_base = Mock(spec=['path', 'name', 'endpoint', 'app'])
        mock_mount_base.path = ""
        mock_mount_base.name = ""
        mock_mount_base.endpoint = None  # Mount objects don't have endpoint

        # Mock app function
        mock_app_function = Mock()
        mock_app_function.__name__ = "handle_streamable_http_mcp"
        mock_mount_base.app = mock_app_function

        # Create Mount object for SSE route (path='/sse')
        mock_mount_sse = Mock(spec=['path', 'name', 'endpoint', 'app'])
        mock_mount_sse.path = "/sse"
        mock_mount_sse.name = ""
        mock_mount_sse.endpoint = None  # Mount objects don't have endpoint

        # Mock app function for SSE
        mock_sse_function = Mock()
        mock_sse_function.__name__ = "handle_sse_mcp"
        mock_mount_sse.app = mock_sse_function

        mock_sub_app.routes.extend([mock_mount_base, mock_mount_sse])
        mock_mount_route.app = mock_sub_app

        result = GetRoutes.get_routes_for_mounted_app(mock_mount_route)

        # Should capture both /mcp and /mcp/sse routes
        assert len(result) == 2

        # Check base MCP route
        base_route = next(r for r in result if r["path"] == "/mcp")
        assert base_route["methods"] == ["GET", "POST"]  # Default methods
        assert base_route["endpoint"] == "handle_streamable_http_mcp"
        assert base_route["mounted_app"] is True

        # Check SSE route
        sse_route = next(r for r in result if r["path"] == "/mcp/sse")
        assert sse_route["methods"] == ["GET", "POST"]  # Default methods
        assert sse_route["endpoint"] == "handle_sse_mcp"
        assert sse_route["mounted_app"] is True

    def test_get_routes_for_mounted_app_mixed_routes(self):
        """Test getting routes for mounted app with both regular routes and Mount objects."""
        # Mock the main mount route
        mock_mount_route = Mock()
        mock_mount_route.path = "/mcp"

        # Mock sub-app
        mock_sub_app = Mock()
        mock_sub_app.routes = []

        # Create a regular API route
        mock_api_route = Mock()
        mock_api_route.path = "/enabled"
        mock_api_route.methods = ["GET"]
        mock_api_route.name = "get_mcp_server_enabled"
        mock_endpoint = Mock()
        mock_endpoint.__name__ = "get_mcp_server_enabled"
        mock_api_route.endpoint = mock_endpoint
        mock_api_route.app = None

        # Create Mount object
        mock_mount_base = Mock(spec=['path', 'name', 'endpoint', 'app'])
        mock_mount_base.path = ""
        mock_mount_base.name = ""
        mock_mount_base.endpoint = None
        mock_app_function = Mock()
        mock_app_function.__name__ = "handle_streamable_http_mcp"
        mock_mount_base.app = mock_app_function

        mock_sub_app.routes.extend([mock_api_route, mock_mount_base])
        mock_mount_route.app = mock_sub_app

        result = GetRoutes.get_routes_for_mounted_app(mock_mount_route)

        # Should capture both the API route and the Mount object
        assert len(result) == 2

        # Check API route
        api_route = next(r for r in result if r["path"] == "/mcp/enabled")
        assert api_route["methods"] == ["GET"]
        assert api_route["endpoint"] == "get_mcp_server_enabled"

        # Check Mount object route
        mount_route = next(r for r in result if r["path"] == "/mcp")
        assert mount_route["endpoint"] == "handle_streamable_http_mcp"
        assert mount_route["mounted_app"] is True

    def test_get_routes_for_mounted_app_with_static_files(self):
        """
        Test getting routes for mounted app with StaticFiles object (reproduces AttributeError bug).

        This test reproduces the exact stacktrace scenario:
        AttributeError: 'StaticFiles' object has no attribute '__name__'. Did you mean: '__ne__'?

        The original bug occurred when the code tried to access endpoint_func.__name__
        directly on a StaticFiles object. The fix uses _safe_get_endpoint_name() which
        gracefully handles objects without __name__ by falling back to class name.
        """
        # Mock the main mount route (e.g., /ui)
        mock_mount_route = Mock()
        mock_mount_route.path = "/ui"

        # Mock sub-app with routes
        mock_sub_app = Mock()
        mock_sub_app.routes = []

        # Create a mock StaticFiles route (this is the problematic case)
        mock_static_route = Mock(spec=['path', 'name', 'endpoint', 'app'])
        mock_static_route.path = ""
        mock_static_route.name = "ui"
        mock_static_route.endpoint = None

        # Mock StaticFiles object - this is the key part that caused the AttributeError
        # Real StaticFiles objects don't have __name__ attribute
        # Create a mock that simulates StaticFiles behavior (no __name__ attribute)
        class StaticFiles:
            """Mock class that simulates real StaticFiles without __name__ attribute"""
            pass

        mock_static_files = StaticFiles()
        # Verify no __name__ attribute exists on the instance (reproduces bug condition)
        assert not hasattr(mock_static_files, '__name__')

        mock_static_route.app = mock_static_files

        mock_sub_app.routes.append(mock_static_route)
        mock_mount_route.app = mock_sub_app

        # This should NOT raise AttributeError thanks to _safe_get_endpoint_name
        # In the old code, this would fail with: 'StaticFiles' object has no attribute '__name__'
        result = GetRoutes.get_routes_for_mounted_app(mock_mount_route)

        # Should handle StaticFiles gracefully without throwing AttributeError
        assert len(result) == 1
        assert result[0]["path"] == "/ui"
        assert result[0]["methods"] == ["GET", "POST"]  # Default methods
        assert result[0]["name"] == "ui"
        # Should fall back to class name since instance doesn't have __name__ attribute
        assert result[0]["endpoint"] == "StaticFiles"  # Falls back to class name
        assert result[0]["mounted_app"] is True
|
||||
|
@@ -0,0 +1,287 @@
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import orjson
|
||||
import pytest
|
||||
from fastapi import Request
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../../../..")
|
||||
) # Adds the parent directory to the system path
|
||||
|
||||
|
||||
import litellm
|
||||
from litellm.proxy.common_utils.http_parsing_utils import (
|
||||
_read_request_body,
|
||||
_safe_get_request_parsed_body,
|
||||
_safe_set_request_parsed_body,
|
||||
get_form_data,
|
||||
)
|
||||
from litellm.proxy._types import ProxyException
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_request_body_caching():
    """
    Test that the request body is cached after the first read and subsequent
    calls use the cached version instead of parsing again.
    """
    payload = {"key": "value"}

    # Build a request whose body() coroutine yields the JSON payload.
    request = MagicMock()
    request.body = AsyncMock(return_value=orjson.dumps(payload))
    request.headers = {"content-type": "application/json"}
    request.scope = {}

    # First read: parses the body and stashes it in request.scope.
    first = await _read_request_body(request)
    assert first == payload
    assert "parsed_body" in request.scope
    assert request.scope["parsed_body"] == (("key",), {"key": "value"})

    # The raw body was consumed exactly once.
    request.body.assert_called_once()
    request.body.reset_mock()

    # Second read: must be served from the cache, never touching body() again.
    second = await _read_request_body(request)
    assert second == {"key": "value"}
    request.body.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_form_data_parsing():
    """
    Test that form data is correctly parsed from the request.
    """
    # Create a mock request with form data
    mock_request = MagicMock()
    test_data = {"name": "test_user", "message": "hello world"}

    # Mock the form method to return the test data as an awaitable
    mock_request.form = AsyncMock(return_value=test_data)
    mock_request.headers = {"content-type": "application/x-www-form-urlencoded"}
    mock_request.scope = {}

    # Parse the form data
    result = await _read_request_body(mock_request)

    # Verify the form data was correctly parsed
    assert result == test_data
    assert "parsed_body" in mock_request.scope
    assert mock_request.scope["parsed_body"] == (
        ("name", "message"),
        {"name": "test_user", "message": "hello world"},
    )

    # Verify form() was called
    mock_request.form.assert_called_once()

    # The body() method must not be used for form-encoded requests.
    # NOTE: the previous check — `not hasattr(mock_request, "body") or not
    # mock_request.body.called` — was vacuously true because MagicMock
    # auto-creates attributes on access, so hasattr() always succeeded on a
    # fresh child mock whose .called is False. assert_not_called() is the
    # meaningful assertion.
    mock_request.body.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_empty_request_body():
    """
    Test handling of empty request bodies.
    """
    # A JSON request whose body() yields zero bytes.
    request = MagicMock()
    request.body = AsyncMock(return_value=b"")
    request.headers = {"content-type": "application/json"}
    request.scope = {}

    parsed = await _read_request_body(request)

    # An empty body parses to an empty dict and is cached as such.
    assert parsed == {}
    assert "parsed_body" in request.scope
    assert request.scope["parsed_body"] == ((), {})

    # The raw body was consumed exactly once.
    request.body.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_circular_reference_handling():
    """
    Test that cached request body isn't modified when the returned result is modified.
    Demonstrates the mutable dictionary reference issue.
    """
    # Create a mock request with initial data
    mock_request = MagicMock()
    initial_body = {
        "model": "gpt-4",
        "messages": [{"role": "user", "content": "Hello"}],
    }

    mock_request.body = AsyncMock(return_value=orjson.dumps(initial_body))
    mock_request.headers = {"content-type": "application/json"}
    mock_request.scope = {}

    # First parse
    result = await _read_request_body(mock_request)

    # Verify initial parse
    assert result["model"] == "gpt-4"
    assert result["messages"] == [{"role": "user", "content": "Hello"}]

    # Modify the result by adding proxy_server_request
    result["proxy_server_request"] = {
        "url": "http://0.0.0.0:4000/v1/chat/completions",
        "method": "POST",
        "headers": {"content-type": "application/json"},
        "body": result,  # Creates circular reference
    }

    # Second parse using the same request — served from the cached value.
    result2 = await _read_request_body(mock_request)
    assert (
        "proxy_server_request" not in result2
    )  # Passing proves the cache was NOT polluted by mutating the first result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_json_parsing_error_handling():
    """
    Test that JSON parsing errors are properly handled and raise ProxyException
    with appropriate error messages.

    NOTE(review): the message-content assertions below depend on the exact
    wording produced by _read_request_body's error handling ("Invalid JSON
    payload", "trailing comma") — update them together with that code.
    """
    # Test case 1: Trailing comma error
    mock_request = MagicMock()
    invalid_json_with_trailing_comma = b'''{
        "model": "gpt-4o",
        "tools": [
            {
                "type": "mcp",
                "server_label": "litellm",
                "headers": {
                    "x-litellm-api-key": "Bearer sk-1234",
                }
            }
        ],
        "input": "Run available tools"
    }'''

    mock_request.body = AsyncMock(return_value=invalid_json_with_trailing_comma)
    mock_request.headers = {"content-type": "application/json"}
    mock_request.scope = {}

    # Should raise ProxyException for trailing comma
    with pytest.raises(ProxyException) as exc_info:
        await _read_request_body(mock_request)

    assert exc_info.value.code == "400"
    assert "Invalid JSON payload" in exc_info.value.message
    assert "trailing comma" in exc_info.value.message

    # Test case 2: Unquoted property name error
    mock_request2 = MagicMock()
    invalid_json_unquoted_property = b'''{
        "model": "gpt-4o",
        "tools": [
            {
                type: "mcp",
                "server_label": "litellm"
            }
        ],
        "input": "Run available tools"
    }'''

    mock_request2.body = AsyncMock(return_value=invalid_json_unquoted_property)
    mock_request2.headers = {"content-type": "application/json"}
    mock_request2.scope = {}

    # Should raise ProxyException for unquoted property
    with pytest.raises(ProxyException) as exc_info2:
        await _read_request_body(mock_request2)

    assert exc_info2.value.code == "400"
    assert "Invalid JSON payload" in exc_info2.value.message

    # Test case 3: Valid JSON should work normally
    mock_request3 = MagicMock()
    valid_json = b'''{
        "model": "gpt-4o",
        "tools": [
            {
                "type": "mcp",
                "server_label": "litellm",
                "headers": {
                    "x-litellm-api-key": "Bearer sk-1234"
                }
            }
        ],
        "input": "Run available tools"
    }'''

    mock_request3.body = AsyncMock(return_value=valid_json)
    mock_request3.headers = {"content-type": "application/json"}
    mock_request3.scope = {}

    # Should parse successfully
    result = await _read_request_body(mock_request3)
    assert result["model"] == "gpt-4o"
    assert result["input"] == "Run available tools"
    assert len(result["tools"]) == 1
    assert result["tools"][0]["type"] == "mcp"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_form_data():
    """
    Test that get_form_data correctly handles form data with array notation.
    Tests audio transcription parameters as a specific example.
    """
    # Create a mock request with transcription form data
    mock_request = MagicMock()

    # Create mock form data with array-notation keys.
    # NOTE: a plain dict cannot carry repeated keys — the original literal
    # listed "timestamp_granularities[]" twice, so the second value silently
    # overwrote the first. Only the effective value ("segment") is kept here;
    # a real starlette FormData is a MultiDict and would carry both values.
    mock_form_data = {
        "file": "file_object",  # In a real request this would be an UploadFile
        "model": "gpt-4o-transcribe",
        "include[]": "logprobs",  # Array notation
        "language": "en",
        "prompt": "Transcribe this audio file",
        "response_format": "json",
        "stream": "false",
        "temperature": "0.2",
        "timestamp_granularities[]": "segment",
    }

    # Mock the form method to return the test data
    mock_request.form = AsyncMock(return_value=mock_form_data)

    # Call the function being tested
    result = await get_form_data(mock_request)

    # Verify regular form fields are preserved
    assert result["file"] == "file_object"
    assert result["model"] == "gpt-4o-transcribe"
    assert result["language"] == "en"
    assert result["prompt"] == "Transcribe this audio file"
    assert result["response_format"] == "json"
    assert result["stream"] == "false"
    assert result["temperature"] == "0.2"

    # Verify array fields ("key[]") are collapsed into lists under "key"
    assert "include" in result
    assert isinstance(result["include"], list)
    assert "logprobs" in result["include"]

    assert "timestamp_granularities" in result
    assert isinstance(result["timestamp_granularities"], list)
    assert "segment" in result["timestamp_granularities"]
|
@@ -0,0 +1,90 @@
|
||||
from unittest.mock import MagicMock, mock_open, patch
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
from litellm.proxy.common_utils.load_config_utils import get_file_contents_from_s3
|
||||
|
||||
|
||||
class TestGetFileContentsFromS3:
    """Test suite for S3 config loading functionality."""

    # NOTE: patch decorators bind bottom-up — yaml.safe_load -> mock_yaml_load,
    # bedrock_converse_chat_completion -> mock_bedrock, boto3.client -> mock_boto3_client.
    @patch('boto3.client')
    @patch('litellm.main.bedrock_converse_chat_completion')
    @patch('yaml.safe_load')
    def test_get_file_contents_from_s3_no_temp_file_creation(
        self, mock_yaml_load, mock_bedrock, mock_boto3_client
    ):
        """
        Test that get_file_contents_from_s3 doesn't create temporary files
        and uses yaml.safe_load directly on the S3 response content.

        Note: It's critical that yaml.safe_load is used

        Relevant issue/PR: https://github.com/BerriAI/litellm/pull/12078
        """
        # Mock credentials (presumably resolved via the bedrock client's
        # get_credentials inside get_file_contents_from_s3 — confirm against source)
        mock_credentials = MagicMock()
        mock_credentials.access_key = "test_access_key"
        mock_credentials.secret_key = "test_secret_key"
        mock_credentials.token = "test_token"
        mock_bedrock.get_credentials.return_value = mock_credentials

        # Mock S3 client and response
        mock_s3_client = MagicMock()
        mock_boto3_client.return_value = mock_s3_client

        # Mock S3 response with YAML content
        yaml_content = """
        model_list:
          - model_name: gpt-3.5-turbo
            litellm_params:
              model: gpt-3.5-turbo
        """
        mock_response_body = MagicMock()
        mock_response_body.read.return_value = yaml_content.encode('utf-8')
        mock_s3_response = {
            'Body': mock_response_body
        }
        mock_s3_client.get_object.return_value = mock_s3_response

        # Mock yaml.safe_load to return parsed config
        expected_config = {
            'model_list': [{
                'model_name': 'gpt-3.5-turbo',
                'litellm_params': {
                    'model': 'gpt-3.5-turbo'
                }
            }]
        }
        mock_yaml_load.return_value = expected_config

        # Call the function
        bucket_name = "test-bucket"
        object_key = "config.yaml"
        result = get_file_contents_from_s3(bucket_name, object_key)

        # Assertions
        assert result == expected_config

        # Verify S3 client was created with correct credentials
        mock_boto3_client.assert_called_once_with(
            "s3",
            aws_access_key_id="test_access_key",
            aws_secret_access_key="test_secret_key",
            aws_session_token="test_token"
        )

        # Verify S3 get_object was called with correct parameters
        mock_s3_client.get_object.assert_called_once_with(
            Bucket=bucket_name,
            Key=object_key
        )

        # Verify the response body was read and decoded
        mock_response_body.read.assert_called_once()

        # Verify yaml.safe_load was called with the decoded content
        # (i.e. no temp file round-trip — the decoded string is parsed directly)
        mock_yaml_load.assert_called_once_with(yaml_content)
|
||||
|
||||
|
@@ -0,0 +1,87 @@
|
||||
import pytest
|
||||
|
||||
from litellm.proxy.common_utils.openai_endpoint_utils import remove_sensitive_info_from_deployment
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "model_config, expected_config",
    [
        # Test case 1: Empty litellm_params
        (
            {
                "model_name": "test-model",
                "litellm_params": {}
            },
            {
                "model_name": "test-model",
                "litellm_params": {}
            }
        ),
        # Test case 2: Full sensitive data removal, mixed secrets of azure, aws, gcp, and typical api_key
        (
            {
                "model_name": "gpt-4",
                "litellm_params": {
                    "model": "openai/gpt-4",
                    "api_key": "sk-sensitive-key-123",
                    "client_secret": "~v8Q4W:Zp9gJ-3sTqX5aB@LkR2mNfYdC",
                    "vertex_credentials": {"type": "service_account"},
                    "aws_access_key_id": "AKIA123456789",
                    "aws_secret_access_key": "secret-access-key",
                    "api_base": "https://api.openai.com/v1",
                    "temperature": 0.7
                },
                "model_info": {"id": "test-id"}
            },
            {
                "model_name": "gpt-4",
                "litellm_params": {
                    "model": "openai/gpt-4",
                    "api_base": "https://api.openai.com/v1",
                    "temperature": 0.7
                },
                "model_info": {"id": "test-id"}
            }
        ),
        # Test case 3: Partial sensitive data, api_key
        (
            {
                "model_name": "claude-3",
                "litellm_params": {
                    "model": "anthropic/claude-3",
                    "api_key": "sk-anthropic-key",
                    "temperature": 0.5
                }
            },
            {
                "model_name": "claude-3",
                "litellm_params": {
                    "model": "anthropic/claude-3",
                    "temperature": 0.5
                }
            }
        ),
        # Test case 4: No sensitive data
        (
            {
                "model_name": "local-model",
                "litellm_params": {
                    "model": "local/model",
                    "temperature": 0.8,
                    "max_tokens": 100
                }
            },
            {
                "model_name": "local-model",
                "litellm_params": {
                    "model": "local/model",
                    "temperature": 0.8,
                    "max_tokens": 100
                }
            }
        )
    ]
)
def test_remove_sensitive_info_from_deployment(model_config: dict, expected_config: dict):
    """Credential-bearing keys are stripped from litellm_params; everything else is untouched."""
    sanitized_config = remove_sensitive_info_from_deployment(model_config)
    assert sanitized_config == expected_config
|
@@ -0,0 +1,322 @@
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import pytest
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../../..")
|
||||
) # Adds the parent directory to the system path
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.proxy.common_utils.reset_budget_job import ResetBudgetJob
|
||||
from litellm.proxy.utils import ProxyLogging
|
||||
|
||||
|
||||
# Mock classes for testing
|
||||
class MockLiteLLMTeamMembership:
    """Stand-in for the prisma `litellm_teammembership` table client."""

    async def update_many(
        self, where: Dict[str, Any], data: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Pretend to bulk-update memberships; always reports one affected row."""
        return {"count": 1}
|
||||
|
||||
|
||||
class MockDB:
    # Mirrors the prisma client's `.db` attribute surface used by ResetBudgetJob.
    def __init__(self):
        # Only the team-membership table is exercised by these tests.
        self.litellm_teammembership = MockLiteLLMTeamMembership()
|
||||
|
||||
|
||||
class MockPrismaClient:
    """In-memory stand-in for the Prisma client used by the reset-budget tests.

    ``data`` holds the rows each test seeds; ``updated_data`` mirrors whatever
    the job wrote back via :meth:`update_data`.
    """

    # Tables the reset-budget job touches.
    _TABLES = ("key", "user", "team", "budget", "enduser")

    def __init__(self):
        self.data: Dict[str, List[Any]] = {name: [] for name in self._TABLES}
        self.updated_data: Dict[str, List[Any]] = {name: [] for name in self._TABLES}
        self.db = MockDB()

    async def get_data(self, table_name, query_type, **kwargs):
        """Return seeded rows, applying the same filters the real queries use."""
        rows = self.data.get(table_name, [])

        # Budget reset queries: simulate "expired" budgets as any row that
        # carries a budget_reset_at attribute.
        if table_name == "budget" and query_type == "find_all" and "reset_at" in kwargs:
            return [row for row in rows if hasattr(row, "budget_reset_at")]

        # End-user queries filtered by a list of budget ids.
        if (
            table_name == "enduser"
            and query_type == "find_all"
            and "budget_id_list" in kwargs
        ):
            wanted_ids = kwargs["budget_id_list"]
            return [
                row
                for row in rows
                if hasattr(row, "litellm_budget_table")
                and hasattr(row.litellm_budget_table, "budget_id")
                and row.litellm_budget_table.budget_id in wanted_ids
            ]

        # Key queries carrying expires/reset_at filters behave like the budget case.
        if (
            table_name == "key"
            and query_type == "find_all"
            and ("expires" in kwargs or "reset_at" in kwargs)
        ):
            return [row for row in rows if hasattr(row, "budget_reset_at")]

        return rows

    async def update_data(self, query_type, data_list, table_name):
        """Record the written rows so tests can assert on them, then echo them back."""
        self.updated_data[table_name] = data_list
        return data_list
|
||||
|
||||
|
||||
class MockProxyLogging:
    """Minimal ProxyLogging stand-in exposing a no-op service logger."""

    class MockServiceLogging:
        """Service-logging stub whose hooks accept any kwargs and do nothing."""

        async def async_service_success_hook(self, **kwargs):
            pass

        async def async_service_failure_hook(self, **kwargs):
            pass

    def __init__(self):
        # The job only ever touches .service_logging_obj on the logging object.
        self.service_logging_obj = self.MockServiceLogging()
|
||||
|
||||
|
||||
# Test fixtures
@pytest.fixture
def mock_prisma_client():
    """Fresh in-memory Prisma stand-in per test."""
    return MockPrismaClient()


@pytest.fixture
def mock_proxy_logging():
    """No-op ProxyLogging stand-in per test."""
    return MockProxyLogging()


@pytest.fixture
def reset_budget_job(mock_prisma_client, mock_proxy_logging):
    """ResetBudgetJob wired to the mock client and logger above."""
    return ResetBudgetJob(
        proxy_logging_obj=mock_proxy_logging, prisma_client=mock_prisma_client
    )


# Helper function to run async tests
async def run_async_test(coro):
    # NOTE(review): currently unused — the tests below call asyncio.run directly.
    return await coro
|
||||
|
||||
|
||||
# Tests
|
||||
def test_reset_budget_for_key(reset_budget_job, mock_prisma_client):
    """A key whose reset time has arrived gets spend zeroed and a future reset date."""
    now = datetime.now(timezone.utc)

    # Minimal dynamic record carrying only the fields the job reads.
    test_key = type(
        "LiteLLM_VerificationToken",
        (),
        {
            "spend": 100.0,
            "budget_duration": "30d",
            "budget_reset_at": now,
            "id": "test-key-1",
        },
    )
    mock_prisma_client.data["key"] = [test_key]

    asyncio.run(reset_budget_job.reset_budget_for_litellm_keys())

    written = mock_prisma_client.updated_data["key"]
    assert len(written) == 1
    refreshed_key = written[0]
    assert refreshed_key.spend == 0.0
    assert refreshed_key.budget_reset_at > now
|
||||
|
||||
|
||||
def test_reset_budget_for_user(reset_budget_job, mock_prisma_client):
    """A user whose reset time has arrived gets spend zeroed and a future reset date."""
    now = datetime.now(timezone.utc)

    # Minimal dynamic record carrying only the fields the job reads.
    test_user = type(
        "LiteLLM_UserTable",
        (),
        {
            "spend": 200.0,
            "budget_duration": "7d",
            "budget_reset_at": now,
            "id": "test-user-1",
        },
    )
    mock_prisma_client.data["user"] = [test_user]

    asyncio.run(reset_budget_job.reset_budget_for_litellm_users())

    written = mock_prisma_client.updated_data["user"]
    assert len(written) == 1
    refreshed_user = written[0]
    assert refreshed_user.spend == 0.0
    assert refreshed_user.budget_reset_at > now
|
||||
|
||||
|
||||
def test_reset_budget_for_team(reset_budget_job, mock_prisma_client):
    """A team whose reset time has arrived gets spend zeroed and a future reset date."""
    now = datetime.now(timezone.utc)

    # Minimal dynamic record carrying only the fields the job reads.
    test_team = type(
        "LiteLLM_TeamTable",
        (),
        {
            "spend": 500.0,
            "budget_duration": "1mo",
            "budget_reset_at": now,
            "id": "test-team-1",
        },
    )
    mock_prisma_client.data["team"] = [test_team]

    asyncio.run(reset_budget_job.reset_budget_for_litellm_teams())

    written = mock_prisma_client.updated_data["team"]
    assert len(written) == 1
    refreshed_team = written[0]
    assert refreshed_team.spend == 0.0
    assert refreshed_team.budget_reset_at > now
|
||||
|
||||
|
||||
def test_reset_budget_for_enduser(reset_budget_job, mock_prisma_client):
    """Endusers linked to an expired budget get spend reset; the budget gets a new reset date."""
    now = datetime.now(timezone.utc)

    # Budget row the enduser is attached to via litellm_budget_table.
    test_budget = type(
        "LiteLLM_BudgetTable",
        (),
        {
            "max_budget": 500.0,
            "budget_duration": "1d",
            "budget_reset_at": now,
            "budget_id": "test-budget-1",
        },
    )
    test_enduser = type(
        "LiteLLM_EndUserTable",
        (),
        {
            "spend": 20.0,
            "litellm_budget_table": test_budget,
            "user_id": "test-enduser-1",
        },
    )
    mock_prisma_client.data["budget"] = [test_budget]
    mock_prisma_client.data["enduser"] = [test_enduser]

    asyncio.run(reset_budget_job.reset_budget_for_litellm_budget_table())

    endusers_written = mock_prisma_client.updated_data["enduser"]
    budgets_written = mock_prisma_client.updated_data["budget"]
    assert len(endusers_written) == 1
    assert len(budgets_written) == 1
    assert endusers_written[0].spend == 0.0
    assert budgets_written[0].budget_reset_at > now
|
||||
|
||||
|
||||
def test_reset_budget_all(reset_budget_job, mock_prisma_client):
    """reset_budget() should sweep keys, users, teams, budgets and endusers in one pass."""
    now = datetime.now(timezone.utc)

    def make(type_name, **attrs):
        # Minimal dynamic record type exposing exactly the given attributes.
        return type(type_name, (), attrs)

    test_key = make(
        "LiteLLM_VerificationToken",
        spend=100.0,
        budget_duration="30d",
        budget_reset_at=now,
        id="test-key-1",
    )
    test_user = make(
        "LiteLLM_UserTable",
        spend=200.0,
        budget_duration="7d",
        budget_reset_at=now,
        id="test-user-1",
    )
    test_team = make(
        "LiteLLM_TeamTable",
        spend=500.0,
        budget_duration="1mo",
        budget_reset_at=now,
        id="test-team-1",
    )
    test_budget = make(
        "LiteLLM_BudgetTable",
        max_budget=500.0,
        budget_duration="1d",
        budget_reset_at=now,
        budget_id="test-budget-1",
    )
    test_enduser = make(
        "LiteLLM_EndUserTable",
        spend=20.0,
        litellm_budget_table=test_budget,
        user_id="test-enduser-1",
    )

    mock_prisma_client.data["key"] = [test_key]
    mock_prisma_client.data["user"] = [test_user]
    mock_prisma_client.data["team"] = [test_team]
    mock_prisma_client.data["budget"] = [test_budget]
    mock_prisma_client.data["enduser"] = [test_enduser]

    asyncio.run(reset_budget_job.reset_budget())

    written = mock_prisma_client.updated_data
    # Exactly one row per table should have been updated.
    for table in ("key", "user", "team", "enduser", "budget"):
        assert len(written[table]) == 1
    # Every spend-bearing record must be reset to zero.
    for table in ("key", "user", "team", "enduser"):
        assert written[table][0].spend == 0.0
|
@@ -0,0 +1,35 @@
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../../..")
|
||||
) # Adds the parent directory to the system path
|
||||
|
||||
from litellm.proxy.common_utils.timezone_utils import get_budget_reset_time
|
||||
|
||||
|
||||
def test_get_budget_reset_time():
    """
    Test that the budget reset time is set to the first of the next month.
    """
    now = datetime.now(timezone.utc)

    # First day of the month following *now*, rolling December over to January.
    if now.month == 12:
        expected_year, expected_month = now.year + 1, 1
    else:
        expected_year, expected_month = now.year, now.month + 1
    expected_reset_at = datetime(expected_year, expected_month, 1, tzinfo=timezone.utc)

    assert get_budget_reset_time(budget_duration="1mo") == expected_reset_at
|
@@ -0,0 +1,318 @@
|
||||
# tests/litellm/proxy/common_utils/test_upsert_budget_membership.py
|
||||
import types
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from litellm.proxy.management_endpoints.common_utils import (
|
||||
_upsert_budget_and_membership,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures: a fake Prisma transaction and a fake UserAPIKeyAuth object
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.fixture
def mock_tx():
    """
    Build an object that looks just enough like the Prisma tx used inside
    _upsert_budget_and_membership: a team-membership table with async
    update/upsert, and a budget table whose create() yields a fake row.
    """
    team_membership = MagicMock()
    team_membership.update = AsyncMock()
    team_membership.upsert = AsyncMock()

    budget_table = MagicMock()
    budget_table.update = AsyncMock()
    # create() returns a fake row exposing .budget_id, like the real client.
    budget_table.create = AsyncMock(
        return_value=types.SimpleNamespace(budget_id="new-budget-123")
    )

    tx = MagicMock()
    tx.litellm_teammembership = team_membership
    tx.litellm_budgettable = budget_table
    return tx
|
||||
|
||||
|
||||
@pytest.fixture
def fake_user():
    """Cheap stand-in for UserAPIKeyAuth: only user_id is read."""
    user = types.SimpleNamespace(user_id="tester@example.com")
    return user
|
||||
|
||||
# TEST: max_budget is None, disconnect only
|
||||
@pytest.mark.asyncio
async def test_upsert_disconnect(mock_tx, fake_user):
    """With max_budget=None the helper must only disconnect the membership's budget."""
    await _upsert_budget_and_membership(
        mock_tx,
        team_id="team-1",
        user_id="user-1",
        max_budget=None,
        existing_budget_id=None,
        user_api_key_dict=fake_user,
    )

    mock_tx.litellm_teammembership.update.assert_awaited_once_with(
        where={"user_id_team_id": {"user_id": "user-1", "team_id": "team-1"}},
        data={"litellm_budget_table": {"disconnect": True}},
    )
    # No budget row may be created or touched on the disconnect path.
    mock_tx.litellm_budgettable.update.assert_not_called()
    mock_tx.litellm_budgettable.create.assert_not_called()
    mock_tx.litellm_teammembership.upsert.assert_not_called()
|
||||
|
||||
|
||||
# TEST: existing budget id, creates new budget (current behavior)
|
||||
@pytest.mark.asyncio
async def test_upsert_with_existing_budget_id_creates_new(mock_tx, fake_user):
    """
    Test that even when existing_budget_id is provided, the function creates a new budget.
    This reflects the current implementation behavior.
    """
    await _upsert_budget_and_membership(
        mock_tx,
        team_id="team-2",
        user_id="user-2",
        max_budget=42.0,
        existing_budget_id="bud-999",  # This parameter is currently unused
        user_api_key_dict=fake_user,
    )

    # A brand-new budget is created rather than the existing one updated.
    mock_tx.litellm_budgettable.create.assert_awaited_once_with(
        data={
            "max_budget": 42.0,
            "created_by": fake_user.user_id,
            "updated_by": fake_user.user_id,
        },
        include={"team_membership": True},
    )

    # Membership is upserted, connecting it to the freshly created budget id.
    fresh_budget_id = mock_tx.litellm_budgettable.create.return_value.budget_id
    link = {"litellm_budget_table": {"connect": {"budget_id": fresh_budget_id}}}
    mock_tx.litellm_teammembership.upsert.assert_awaited_once_with(
        where={"user_id_team_id": {"user_id": "user-2", "team_id": "team-2"}},
        data={
            "create": {"user_id": "user-2", "team_id": "team-2", **link},
            "update": link,
        },
    )

    # Neither the budget nor the membership is updated in place.
    mock_tx.litellm_budgettable.update.assert_not_called()
    mock_tx.litellm_teammembership.update.assert_not_called()
|
||||
|
||||
|
||||
# TEST: create new budget and link membership
|
||||
@pytest.mark.asyncio
async def test_upsert_create_and_link(mock_tx, fake_user):
    """Creating a budget from scratch must also link the team membership to it."""
    await _upsert_budget_and_membership(
        mock_tx,
        team_id="team-3",
        user_id="user-3",
        max_budget=99.9,
        existing_budget_id=None,
        user_api_key_dict=fake_user,
    )

    mock_tx.litellm_budgettable.create.assert_awaited_once_with(
        data={
            "max_budget": 99.9,
            "created_by": fake_user.user_id,
            "updated_by": fake_user.user_id,
        },
        include={"team_membership": True},
    )

    # Budget id handed out by the mocked create().
    created_id = mock_tx.litellm_budgettable.create.return_value.budget_id
    link = {"litellm_budget_table": {"connect": {"budget_id": created_id}}}
    mock_tx.litellm_teammembership.upsert.assert_awaited_once_with(
        where={"user_id_team_id": {"user_id": "user-3", "team_id": "team-3"}},
        data={
            "create": {"user_id": "user-3", "team_id": "team-3", **link},
            "update": link,
        },
    )

    # No in-place updates are expected on the create path.
    mock_tx.litellm_teammembership.update.assert_not_called()
    mock_tx.litellm_budgettable.update.assert_not_called()
|
||||
|
||||
|
||||
# TEST: create new budget and link membership, then create another new budget
|
||||
@pytest.mark.asyncio
async def test_upsert_create_then_create_another(mock_tx, fake_user):
    """
    Test that multiple calls to _upsert_budget_and_membership create separate budgets,
    reflecting the current implementation behavior.
    """
    # FIRST CALL: no existing budget, so a new one is created and linked.
    await _upsert_budget_and_membership(
        mock_tx,
        team_id="team-42",
        user_id="user-42",
        max_budget=10.0,
        existing_budget_id=None,
        user_api_key_dict=fake_user,
    )

    # Remember the budget id that the first create() handed out.
    first_budget_id = mock_tx.litellm_budgettable.create.return_value.budget_id

    # Sanity: the create + upsert path really ran.
    mock_tx.litellm_budgettable.create.assert_awaited_once()
    mock_tx.litellm_teammembership.upsert.assert_awaited_once()

    # SECOND CALL: clear call history and hand out a different budget id.
    for mocked in (
        mock_tx.litellm_budgettable.create,
        mock_tx.litellm_teammembership.upsert,
        mock_tx.litellm_budgettable.update,
    ):
        mocked.reset_mock()
    mock_tx.litellm_budgettable.create.return_value = types.SimpleNamespace(
        budget_id="new-budget-456"
    )

    await _upsert_budget_and_membership(
        mock_tx,
        team_id="team-42",
        user_id="user-42",
        max_budget=25.0,  # new limit
        existing_budget_id=first_budget_id,  # ignored by the current implementation
        user_api_key_dict=fake_user,
    )

    # Another brand-new budget is created, not an update of the first one.
    mock_tx.litellm_budgettable.create.assert_awaited_once_with(
        data={
            "max_budget": 25.0,
            "created_by": fake_user.user_id,
            "updated_by": fake_user.user_id,
        },
        include={"team_membership": True},
    )

    second_budget_id = mock_tx.litellm_budgettable.create.return_value.budget_id
    link = {"litellm_budget_table": {"connect": {"budget_id": second_budget_id}}}
    mock_tx.litellm_teammembership.upsert.assert_awaited_once_with(
        where={"user_id_team_id": {"user_id": "user-42", "team_id": "team-42"}},
        data={
            "create": {"user_id": "user-42", "team_id": "team-42", **link},
            "update": link,
        },
    )

    # The pre-existing budget row itself is never updated.
    mock_tx.litellm_budgettable.update.assert_not_called()
|
||||
|
||||
|
||||
# TEST: update rpm_limit for member with existing budget_id
|
||||
@pytest.mark.asyncio
async def test_upsert_rpm_limit_update_creates_new_budget(mock_tx, fake_user):
    """
    Test that updating rpm_limit for a member with an existing budget_id
    creates a new budget with the new rpm/tpm limits and assigns it to the user.
    """
    await _upsert_budget_and_membership(
        mock_tx,
        team_id="team-rpm-test",
        user_id="user-rpm-test",
        max_budget=50.0,
        existing_budget_id="existing-budget-456",
        user_api_key_dict=fake_user,
        tpm_limit=1000,
        rpm_limit=100,  # updating rpm_limit
    )

    # A fresh budget row carries every specified limit.
    mock_tx.litellm_budgettable.create.assert_awaited_once_with(
        data={
            "max_budget": 50.0,
            "tpm_limit": 1000,
            "rpm_limit": 100,
            "created_by": fake_user.user_id,
            "updated_by": fake_user.user_id,
        },
        include={"team_membership": True},
    )

    # The pre-existing budget must not be modified in place.
    mock_tx.litellm_budgettable.update.assert_not_called()

    # Membership gets connected to the newly minted budget id.
    fresh_budget_id = mock_tx.litellm_budgettable.create.return_value.budget_id
    link = {"litellm_budget_table": {"connect": {"budget_id": fresh_budget_id}}}
    mock_tx.litellm_teammembership.upsert.assert_awaited_once_with(
        where={"user_id_team_id": {"user_id": "user-rpm-test", "team_id": "team-rpm-test"}},
        data={
            "create": {"user_id": "user-rpm-test", "team_id": "team-rpm-test", **link},
            "update": link,
        },
    )
|
||||
|
||||
|
||||
# TEST: create new budget with only rpm_limit (no max_budget)
|
||||
@pytest.mark.asyncio
async def test_upsert_rpm_only_creates_new_budget(mock_tx, fake_user):
    """
    Test that setting only rpm_limit creates a new budget with just the rpm_limit.
    """
    await _upsert_budget_and_membership(
        mock_tx,
        team_id="team-rpm-only",
        user_id="user-rpm-only",
        max_budget=None,
        existing_budget_id=None,
        user_api_key_dict=fake_user,
        rpm_limit=50,
    )

    # Only rpm_limit (plus audit fields) appears in the created budget.
    mock_tx.litellm_budgettable.create.assert_awaited_once_with(
        data={
            "rpm_limit": 50,
            "created_by": fake_user.user_id,
            "updated_by": fake_user.user_id,
        },
        include={"team_membership": True},
    )

    # Membership is upserted against the new budget id.
    fresh_budget_id = mock_tx.litellm_budgettable.create.return_value.budget_id
    link = {"litellm_budget_table": {"connect": {"budget_id": fresh_budget_id}}}
    mock_tx.litellm_teammembership.upsert.assert_awaited_once_with(
        where={"user_id_team_id": {"user_id": "user-rpm-only", "team_id": "team-rpm-only"}},
        data={
            "create": {"user_id": "user-rpm-only", "team_id": "team-rpm-only", **link},
            "update": link,
        },
    )
|
Reference in New Issue
Block a user