Added LiteLLM to the stack
File diff suppressed because it is too large
@@ -0,0 +1,785 @@
"""
Unit Tests for the max parallel request limiter v3 for the proxy
"""

import asyncio
import os
import sys
from datetime import datetime
from typing import Any, Dict, List, Optional

import pytest
from fastapi import HTTPException

import litellm
from litellm import Router
from litellm.caching.caching import DualCache
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.hooks.parallel_request_limiter_v3 import (
    _PROXY_MaxParallelRequestsHandler_v3 as _PROXY_MaxParallelRequestsHandler,
)
from litellm.proxy.utils import InternalUsageCache, ProxyLogging, hash_token
from litellm.types.utils import ModelResponse, Usage

@pytest.mark.flaky(reruns=3)
@pytest.mark.asyncio
async def test_sliding_window_rate_limit_v3(monkeypatch):
    """
    Test the sliding window rate limiting functionality
    """
    monkeypatch.setenv("LITELLM_RATE_LIMIT_WINDOW_SIZE", "2")
    _api_key = "sk-12345"
    _api_key = hash_token(_api_key)
    user_api_key_dict = UserAPIKeyAuth(api_key=_api_key, rpm_limit=3)
    local_cache = DualCache()
    parallel_request_handler = _PROXY_MaxParallelRequestsHandler(
        internal_usage_cache=InternalUsageCache(local_cache)
    )

    # Mock the batch_rate_limiter_script to simulate window expiry and use correct key construction
    window_starts: Dict[str, int] = {}

    async def mock_batch_rate_limiter(*args, **kwargs):
        keys = kwargs.get("keys") if kwargs else args[0]
        args_list = kwargs.get("args") if kwargs else args[1]
        now = args_list[0]
        window_size = args_list[1]
        results = []
        for i in range(0, len(keys), 3):
            window_key = keys[i]
            counter_key = keys[i + 1]
            # Simulate window expiry
            prev_window = window_starts.get(window_key)
            prev_counter = await local_cache.async_get_cache(key=counter_key) or 0
            if prev_window is None or (now - prev_window) >= window_size:
                # Window expired, reset
                window_starts[window_key] = now
                new_counter = 1
                await local_cache.async_set_cache(
                    key=window_key, value=now, ttl=window_size
                )
                await local_cache.async_set_cache(
                    key=counter_key, value=new_counter, ttl=window_size
                )
            else:
                new_counter = prev_counter + 1
                await local_cache.async_set_cache(
                    key=counter_key, value=new_counter, ttl=window_size
                )
            results.append(now)
            results.append(new_counter)
        return results

    parallel_request_handler.batch_rate_limiter_script = mock_batch_rate_limiter

    # First request should succeed
    await parallel_request_handler.async_pre_call_hook(
        user_api_key_dict=user_api_key_dict, cache=local_cache, data={}, call_type=""
    )

    # Second request should succeed
    await parallel_request_handler.async_pre_call_hook(
        user_api_key_dict=user_api_key_dict, cache=local_cache, data={}, call_type=""
    )

    # Third request should fail
    with pytest.raises(HTTPException) as exc_info:
        await parallel_request_handler.async_pre_call_hook(
            user_api_key_dict=user_api_key_dict,
            cache=local_cache,
            data={},
            call_type="",
        )
    assert exc_info.value.status_code == 429
    assert "Rate limit exceeded" in str(exc_info.value.detail)

    # Wait for window to expire (2 seconds)
    await asyncio.sleep(3)

    print("WAITED 3 seconds")

    print(f"local_cache: {local_cache.in_memory_cache.cache_dict}")

    # After window expires, should be able to make requests again
    await parallel_request_handler.async_pre_call_hook(
        user_api_key_dict=user_api_key_dict, cache=local_cache, data={}, call_type=""
    )

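# Illustrative sketch (not from the LiteLLM implementation): the mock above emulates a
# fixed-window limiter keyed by a window-start entry plus a request counter. A minimal
# standalone version of that accounting, assuming only a plain in-memory dict as the
# store, could look like this:
def _example_fixed_window_allow(
    store: Dict[str, Any], key: str, now: float, window_size: int, limit: int
) -> bool:
    """Return True if one more request under `key` fits in the current window."""
    window_start, counter = store.get(key, (None, 0))
    if window_start is None or (now - window_start) >= window_size:
        # Window expired (or first request): start a new window containing this request.
        store[key] = (now, 1)
        return True
    if counter >= limit:
        return False  # limit already reached inside the current window
    store[key] = (window_start, counter + 1)
    return True
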
@pytest.mark.asyncio
async def test_rate_limiter_script_return_values_v3(monkeypatch):
    """
    Test that the rate limiter script returns both counter and window values correctly
    """
    monkeypatch.setenv("LITELLM_RATE_LIMIT_WINDOW_SIZE", "2")
    _api_key = "sk-12345"
    _api_key = hash_token(_api_key)
    user_api_key_dict = UserAPIKeyAuth(api_key=_api_key, rpm_limit=3)
    local_cache = DualCache()
    parallel_request_handler = _PROXY_MaxParallelRequestsHandler(
        internal_usage_cache=InternalUsageCache(local_cache)
    )

    # Mock the batch_rate_limiter_script to simulate window expiry and use correct key construction
    window_starts: Dict[str, int] = {}

    async def mock_batch_rate_limiter(*args, **kwargs):
        keys = kwargs.get("keys") if kwargs else args[0]
        args_list = kwargs.get("args") if kwargs else args[1]
        now = args_list[0]
        window_size = args_list[1]
        results = []
        for i in range(0, len(keys), 3):
            window_key = keys[i]
            counter_key = keys[i + 1]
            # Simulate window expiry
            prev_window = window_starts.get(window_key)
            prev_counter = await local_cache.async_get_cache(key=counter_key) or 0
            if prev_window is None or (now - prev_window) >= window_size:
                # Window expired, reset
                window_starts[window_key] = now
                new_counter = 1
                await local_cache.async_set_cache(
                    key=window_key, value=now, ttl=window_size
                )
                await local_cache.async_set_cache(
                    key=counter_key, value=new_counter, ttl=window_size
                )
            else:
                new_counter = prev_counter + 1
                await local_cache.async_set_cache(
                    key=counter_key, value=new_counter, ttl=window_size
                )
            results.append(now)
            results.append(new_counter)
        return results

    parallel_request_handler.batch_rate_limiter_script = mock_batch_rate_limiter

    # Make first request
    await parallel_request_handler.async_pre_call_hook(
        user_api_key_dict=user_api_key_dict, cache=local_cache, data={}, call_type=""
    )

    # Verify both counter and window values are stored in cache
    window_key = f"{{api_key:{_api_key}}}:window"
    counter_key = f"{{api_key:{_api_key}}}:requests"

    window_value = await local_cache.async_get_cache(key=window_key)
    counter_value = await local_cache.async_get_cache(key=counter_key)

    assert window_value is not None, "Window value should be stored in cache"
    assert counter_value is not None, "Counter value should be stored in cache"
    assert counter_value == 1, "Counter should be 1 after first request"

    # Make second request
    await parallel_request_handler.async_pre_call_hook(
        user_api_key_dict=user_api_key_dict, cache=local_cache, data={}, call_type=""
    )

    # Verify counter increased but window stayed same
    new_window_value = await local_cache.async_get_cache(key=window_key)
    new_counter_value = await local_cache.async_get_cache(key=counter_key)

    assert (
        new_window_value == window_value
    ), "Window value should not change within window"
    assert new_counter_value == 2, "Counter should be 2 after second request"

    # Wait for window to expire
    await asyncio.sleep(3)

    # Make request after window expiry
    await parallel_request_handler.async_pre_call_hook(
        user_api_key_dict=user_api_key_dict, cache=local_cache, data={}, call_type=""
    )

    # Verify new window and reset counter
    final_window_value = await local_cache.async_get_cache(key=window_key)
    final_counter_value = await local_cache.async_get_cache(key=counter_key)

    assert final_window_value != window_value, "Window value should change after expiry"
    assert final_counter_value == 1, "Counter should reset to 1 after window expiry"

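# Illustrative note (not from the LiteLLM implementation): the cache keys asserted above
# follow a "{<object_type>:<value>}:<suffix>" shape, e.g. "{api_key:<hashed key>}:window"
# and "{api_key:<hashed key>}:requests". A hypothetical helper that builds keys of that
# shape for a single rate-limit object could look like this:
def _example_build_rate_limit_keys(object_type: str, value: str) -> Dict[str, str]:
    """Build window/counter cache keys in the same format the assertions above expect."""
    prefix = f"{{{object_type}:{value}}}"
    return {
        "window": f"{prefix}:window",  # stores the window start timestamp
        "requests": f"{prefix}:requests",  # stores the request counter for the window
    }
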
@pytest.mark.parametrize(
    "rate_limit_object",
    [
        "api_key",
        "model_per_key",
        "user",
        "end_user",
        "team",
    ],
)
@pytest.mark.flaky(reruns=3)
@pytest.mark.asyncio
async def test_normal_router_call_tpm_v3(monkeypatch, rate_limit_object):
    """
    Test normal router call with parallel request limiter v3 for TPM rate limiting
    """
    monkeypatch.setenv("LITELLM_RATE_LIMIT_WINDOW_SIZE", "2")
    model_list = [
        {
            "model_name": "azure-model",
            "litellm_params": {
                "model": "azure/gpt-turbo",
                "api_key": "os.environ/AZURE_FRANCE_API_KEY",
                "api_base": "https://openai-france-1234.openai.azure.com",
                "rpm": 1440,
            },
            "model_info": {"id": 1},
        },
        {
            "model_name": "azure-model",
            "litellm_params": {
                "model": "azure/gpt-35-turbo",
                "api_key": "os.environ/AZURE_EUROPE_API_KEY",
                "api_base": "https://my-endpoint-europe-berri-992.openai.azure.com",
                "rpm": 6,
            },
            "model_info": {"id": 2},
        },
    ]
    router = Router(
        model_list=model_list,
        set_verbose=False,
        num_retries=3,
    )  # type: ignore

    _api_key = "sk-12345"
    _api_key = hash_token(_api_key)
    if rate_limit_object == "api_key":
        user_api_key_dict = UserAPIKeyAuth(api_key=_api_key, tpm_limit=10)
    elif rate_limit_object == "user":
        user_api_key_dict = UserAPIKeyAuth(user_id="12345", user_tpm_limit=10)
    elif rate_limit_object == "team":
        user_api_key_dict = UserAPIKeyAuth(team_id="12345", team_tpm_limit=10)
    elif rate_limit_object == "end_user":
        user_api_key_dict = UserAPIKeyAuth(end_user_id="12345", end_user_tpm_limit=10)
    elif rate_limit_object == "model_per_key":
        user_api_key_dict = UserAPIKeyAuth(
            api_key=_api_key,
            metadata={"model_tpm_limit": {"azure-model": 10}},
        )
    local_cache = DualCache()
    parallel_request_handler = _PROXY_MaxParallelRequestsHandler(
        internal_usage_cache=InternalUsageCache(local_cache)
    )

    # Mock the batch_rate_limiter_script to simulate window expiry and use correct key construction
    window_starts: Dict[str, int] = {}

    async def mock_batch_rate_limiter(*args, **kwargs):
        print(f"args: {args}, kwargs: {kwargs}")
        keys = kwargs.get("keys") if kwargs else args[0]
        args_list = kwargs.get("args") if kwargs else args[1]
        now = args_list[0]
        window_size = args_list[1]
        results = []
        for i in range(0, len(keys), 3):
            window_key = keys[i]
            counter_key = keys[i + 1]
            # Simulate window expiry
            prev_window = window_starts.get(window_key)
            prev_counter = await local_cache.async_get_cache(key=counter_key) or 0
            if prev_window is None or (now - prev_window) >= window_size:
                # Window expired, reset
                window_starts[window_key] = now
                new_counter = 1
                await local_cache.async_set_cache(
                    key=window_key, value=now, ttl=window_size
                )
                await local_cache.async_set_cache(
                    key=counter_key, value=new_counter, ttl=window_size
                )
            else:
                new_counter = prev_counter + 1
                await local_cache.async_set_cache(
                    key=counter_key, value=new_counter, ttl=window_size
                )
            results.append(now)
            results.append(new_counter)
        return results

    parallel_request_handler.batch_rate_limiter_script = mock_batch_rate_limiter
    monkeypatch.setattr(litellm, "callbacks", [parallel_request_handler])

    # Helper to get the correct value for key construction
    def get_value_for_key(rate_limit_object, user_api_key_dict, model_name):
        if rate_limit_object == "api_key":
            return user_api_key_dict.api_key
        elif rate_limit_object == "user":
            return user_api_key_dict.user_id
        elif rate_limit_object == "team":
            return user_api_key_dict.team_id
        elif rate_limit_object == "end_user":
            return user_api_key_dict.end_user_id
        elif rate_limit_object == "model_per_key":
            return f"{user_api_key_dict.api_key}:{model_name}"
        return None

    value = get_value_for_key(rate_limit_object, user_api_key_dict, "azure-model")
    counter_key = parallel_request_handler.create_rate_limit_keys(
        rate_limit_object, value, "tokens"
    )

    # First request should succeed
    await parallel_request_handler.async_pre_call_hook(
        user_api_key_dict=user_api_key_dict,
        cache=local_cache,
        data={"model": "azure-model"},
        call_type="",
    )

    # normal call
    response = await router.acompletion(
        model="azure-model",
        messages=[{"role": "user", "content": "Hey, how's it going?"}],
        metadata={
            "user_api_key": _api_key,
            "user_api_key_user_id": user_api_key_dict.user_id,
            "user_api_key_team_id": user_api_key_dict.team_id,
            "user_api_key_end_user_id": user_api_key_dict.end_user_id,
        },
        mock_response="hello",
    )
    await asyncio.sleep(1)  # success is done in a separate thread

    # Verify the token count is tracked
    counter_value = await local_cache.async_get_cache(key=counter_key)
    print(f"local_cache: {local_cache.in_memory_cache.cache_dict}")

    assert (
        counter_value is not None
    ), f"Counter value should be stored in cache for {counter_key}"

    # Make another request to test rate limiting
    with pytest.raises(HTTPException) as exc_info:
        await parallel_request_handler.async_pre_call_hook(
            user_api_key_dict=user_api_key_dict,
            cache=local_cache,
            data={"model": "azure-model"},
            call_type="",
        )

    # Wait for window to expire
    await asyncio.sleep(3)

    # Make request after window expiry
    await parallel_request_handler.async_pre_call_hook(
        user_api_key_dict=user_api_key_dict,
        cache=local_cache,
        data={"model": "azure-model"},
        call_type="",
    )

    # Verify new window and reset counter
    final_counter_value = await local_cache.async_get_cache(key=counter_key)

    assert final_counter_value == 1, "Counter should reset to 1 after window expiry"

@pytest.mark.parametrize(
    "token_rate_limit_type",
    ["input", "output", "total"],
)
@pytest.mark.asyncio
async def test_token_rate_limit_type_respected_v3(monkeypatch, token_rate_limit_type):
    """
    Test that the token_rate_limit_type setting is respected when incrementing usage
    """
    # Set up environment and mock general_settings
    monkeypatch.setenv("LITELLM_RATE_LIMIT_WINDOW_SIZE", "60")

    _api_key = "sk-12345"
    _api_key = hash_token(_api_key)
    user_api_key_dict = UserAPIKeyAuth(api_key=_api_key, tpm_limit=100)
    local_cache = DualCache()
    parallel_request_handler = _PROXY_MaxParallelRequestsHandler(
        internal_usage_cache=InternalUsageCache(local_cache)
    )

    # Mock the get_rate_limit_type method directly since it imports general_settings internally
    def mock_get_rate_limit_type():
        return token_rate_limit_type

    monkeypatch.setattr(
        parallel_request_handler, "get_rate_limit_type", mock_get_rate_limit_type
    )

    # Create a mock response with different token counts
    mock_usage = Usage(prompt_tokens=20, completion_tokens=30, total_tokens=50)
    mock_response = ModelResponse(
        id="mock-response",
        object="chat.completion",
        created=int(datetime.now().timestamp()),
        model="gpt-3.5-turbo",
        usage=mock_usage,
        choices=[],
    )

    # Create mock kwargs for the success event
    mock_kwargs = {
        "litellm_params": {
            "metadata": {
                "user_api_key": _api_key,
                "user_api_key_user_id": None,
                "user_api_key_team_id": None,
                "user_api_key_end_user_id": None,
            }
        },
        "model": "gpt-3.5-turbo",
    }

    # Mock the pipeline increment method to capture the operations
    captured_operations = []

    async def mock_increment_pipeline(increment_list, **kwargs):
        captured_operations.extend(increment_list)
        return True

    monkeypatch.setattr(
        parallel_request_handler.internal_usage_cache.dual_cache,
        "async_increment_cache_pipeline",
        mock_increment_pipeline,
    )

    # Call the success event handler
    await parallel_request_handler.async_log_success_event(
        kwargs=mock_kwargs,
        response_obj=mock_response,
        start_time=datetime.now(),
        end_time=datetime.now(),
    )

    # Verify that the correct token count was used based on the rate limit type
    assert (
        len(captured_operations) == 2
    ), "Should have 2 operations: max_parallel_requests decrement and TPM increment"

    # Find the TPM increment operation (not the max_parallel_requests decrement)
    tpm_operation = None
    for op in captured_operations:
        if op["key"].endswith(":tokens"):
            tpm_operation = op
            break

    assert tpm_operation is not None, "Should have a TPM increment operation"

    # Check that the correct token count was used
    expected_tokens = {
        "input": mock_usage.prompt_tokens,  # 20
        "output": mock_usage.completion_tokens,  # 30 (note: the implementation may use total_tokens for output, which might be a bug)
        "total": mock_usage.total_tokens,  # 50
    }

    assert (
        tpm_operation["increment_value"] == expected_tokens[token_rate_limit_type]
    ), f"Expected {expected_tokens[token_rate_limit_type]} tokens for type '{token_rate_limit_type}', got {tpm_operation['increment_value']}"

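# Illustrative sketch (not from the LiteLLM implementation): the mapping asserted above
# picks which Usage field is charged against the TPM counter. A standalone version of
# that selection, assuming the same "input" / "output" / "total" setting names:
def _example_tokens_to_count(usage: Usage, token_rate_limit_type: str) -> int:
    """Select the token count to increment, based on the configured limit type."""
    if token_rate_limit_type == "input":
        return usage.prompt_tokens or 0
    if token_rate_limit_type == "output":
        return usage.completion_tokens or 0
    return usage.total_tokens or 0  # "total" (treated as the default here)
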
@pytest.mark.asyncio
async def test_async_log_failure_event_v3():
    """
    Simple test for async_log_failure_event - should decrement max_parallel_requests by 1
    """
    _api_key = "sk-12345"
    _api_key = hash_token(_api_key)
    local_cache = DualCache()
    parallel_request_handler = _PROXY_MaxParallelRequestsHandler(
        internal_usage_cache=InternalUsageCache(local_cache)
    )

    # Mock kwargs with user_api_key
    mock_kwargs = {"litellm_params": {"metadata": {"user_api_key": _api_key}}}

    # Capture pipeline operations
    captured_ops = []

    async def mock_pipeline(increment_list, **kwargs):
        captured_ops.extend(increment_list)

    parallel_request_handler.internal_usage_cache.dual_cache.async_increment_cache_pipeline = (
        mock_pipeline
    )

    # Call async_log_failure_event
    await parallel_request_handler.async_log_failure_event(
        kwargs=mock_kwargs, response_obj=None, start_time=None, end_time=None
    )

    # Verify correct operation was created
    assert len(captured_ops) == 1
    op = captured_ops[0]
    assert op["key"] == f"{{api_key:{_api_key}}}:max_parallel_requests"
    assert op["increment_value"] == -1
    assert op["ttl"] == 60  # default window size

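# Illustrative note (not from the LiteLLM implementation): the pipeline operations captured
# above are plain dicts. Based solely on the assertions in this test, a decrement entry for
# the max-parallel-requests counter could be constructed like this (60s window assumed):
def _example_parallel_request_decrement(
    hashed_api_key: str, window_size: int = 60
) -> Dict[str, Any]:
    """Build one increment-pipeline operation that releases a parallel-request slot."""
    return {
        "key": f"{{api_key:{hashed_api_key}}}:max_parallel_requests",
        "increment_value": -1,  # release one in-flight request
        "ttl": window_size,
    }
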
@pytest.mark.asyncio
async def test_should_rate_limit_only_called_when_limits_exist_v3():
    """
    Test that should_rate_limit is only called when actual rate limits are configured.
    This verifies the optimization that avoids unnecessary rate limit checks.
    """
    _api_key = "sk-12345"
    _api_key = hash_token(_api_key)
    local_cache = DualCache()
    parallel_request_handler = _PROXY_MaxParallelRequestsHandler(
        internal_usage_cache=InternalUsageCache(local_cache)
    )

    # Mock should_rate_limit to track if it's called
    should_rate_limit_called = False

    async def mock_should_rate_limit(*args, **kwargs):
        nonlocal should_rate_limit_called
        should_rate_limit_called = True
        return {"overall_code": "OK", "statuses": []}

    parallel_request_handler.should_rate_limit = mock_should_rate_limit

    # Test 1: No rate limits configured - should_rate_limit should NOT be called
    should_rate_limit_called = False
    user_api_key_dict_no_limits = UserAPIKeyAuth(
        api_key=_api_key,
        user_id="test_user",
        team_id="test_team",
        end_user_id="test_end_user",
        # No rpm_limit, tpm_limit, max_parallel_requests, etc.
    )

    await parallel_request_handler.async_pre_call_hook(
        user_api_key_dict=user_api_key_dict_no_limits,
        cache=local_cache,
        data={"model": "gpt-3.5-turbo"},
        call_type="",
    )

    assert (
        not should_rate_limit_called
    ), "should_rate_limit should not be called when no rate limits are configured"

    # Test 2: API key rate limits configured - should_rate_limit SHOULD be called
    should_rate_limit_called = False
    user_api_key_dict_with_api_limits = UserAPIKeyAuth(
        api_key=_api_key,
        rpm_limit=100,  # Rate limit configured
    )

    await parallel_request_handler.async_pre_call_hook(
        user_api_key_dict=user_api_key_dict_with_api_limits,
        cache=local_cache,
        data={"model": "gpt-3.5-turbo"},
        call_type="",
    )

    assert (
        should_rate_limit_called
    ), "should_rate_limit should be called when API key rate limits are configured"

    # Test 3: User rate limits configured - should_rate_limit SHOULD be called
    should_rate_limit_called = False
    user_api_key_dict_with_user_limits = UserAPIKeyAuth(
        api_key=_api_key,
        user_id="test_user",
        user_tpm_limit=1000,  # User rate limit configured
    )

    await parallel_request_handler.async_pre_call_hook(
        user_api_key_dict=user_api_key_dict_with_user_limits,
        cache=local_cache,
        data={"model": "gpt-3.5-turbo"},
        call_type="",
    )

    assert (
        should_rate_limit_called
    ), "should_rate_limit should be called when user rate limits are configured"

    # Test 4: Team rate limits configured - should_rate_limit SHOULD be called
    should_rate_limit_called = False
    user_api_key_dict_with_team_limits = UserAPIKeyAuth(
        api_key=_api_key,
        team_id="test_team",
        team_rpm_limit=500,  # Team rate limit configured
    )

    await parallel_request_handler.async_pre_call_hook(
        user_api_key_dict=user_api_key_dict_with_team_limits,
        cache=local_cache,
        data={"model": "gpt-3.5-turbo"},
        call_type="",
    )

    assert (
        should_rate_limit_called
    ), "should_rate_limit should be called when team rate limits are configured"

    # Test 5: End user rate limits configured - should_rate_limit SHOULD be called
    should_rate_limit_called = False
    user_api_key_dict_with_end_user_limits = UserAPIKeyAuth(
        api_key=_api_key,
        end_user_id="test_end_user",
        end_user_rpm_limit=200,  # End user rate limit configured
    )

    await parallel_request_handler.async_pre_call_hook(
        user_api_key_dict=user_api_key_dict_with_end_user_limits,
        cache=local_cache,
        data={"model": "gpt-3.5-turbo"},
        call_type="",
    )

    assert (
        should_rate_limit_called
    ), "should_rate_limit should be called when end user rate limits are configured"

    # Test 6: Max parallel requests configured - should_rate_limit SHOULD be called
    should_rate_limit_called = False
    user_api_key_dict_with_parallel_limits = UserAPIKeyAuth(
        api_key=_api_key,
        max_parallel_requests=5,  # Max parallel requests configured
    )

    await parallel_request_handler.async_pre_call_hook(
        user_api_key_dict=user_api_key_dict_with_parallel_limits,
        cache=local_cache,
        data={"model": "gpt-3.5-turbo"},
        call_type="",
    )

    assert (
        should_rate_limit_called
    ), "should_rate_limit should be called when max parallel requests are configured"

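# Illustrative sketch (not from the LiteLLM implementation): the optimization exercised
# above amounts to "only run the rate-limit check when some limit is actually set".
# A simplified predicate over the UserAPIKeyAuth fields that this test configures:
def _example_has_any_rate_limit(user: UserAPIKeyAuth) -> bool:
    """Return True if any of the limits exercised by the test above is configured."""
    return any(
        value is not None
        for value in (
            user.rpm_limit,
            user.tpm_limit,
            user.max_parallel_requests,
            user.user_tpm_limit,
            user.team_rpm_limit,
            user.end_user_rpm_limit,
        )
    )
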
@pytest.mark.asyncio
async def test_model_specific_rate_limits_only_called_when_configured_v3():
    """
    Test that model-specific rate limits only trigger should_rate_limit when actually configured for the requested model.
    """
    from litellm.proxy.auth.auth_utils import (
        get_key_model_rpm_limit,
        get_key_model_tpm_limit,
    )

    _api_key = "sk-12345"
    _api_key = hash_token(_api_key)
    local_cache = DualCache()
    parallel_request_handler = _PROXY_MaxParallelRequestsHandler(
        internal_usage_cache=InternalUsageCache(local_cache)
    )

    # Mock should_rate_limit to track if it's called
    should_rate_limit_called = False

    async def mock_should_rate_limit(*args, **kwargs):
        nonlocal should_rate_limit_called
        should_rate_limit_called = True
        return {"overall_code": "OK", "statuses": []}

    parallel_request_handler.should_rate_limit = mock_should_rate_limit

    # Test 1: Model-specific rate limits configured but for different model - should NOT be called
    should_rate_limit_called = False
    user_api_key_dict_with_model_limits = UserAPIKeyAuth(
        api_key=_api_key,
        metadata={
            "model_tpm_limit": {"gpt-4": 1000}
        },  # Rate limit for gpt-4, not gpt-3.5-turbo
    )

    await parallel_request_handler.async_pre_call_hook(
        user_api_key_dict=user_api_key_dict_with_model_limits,
        cache=local_cache,
        data={"model": "gpt-3.5-turbo"},  # Requesting different model
        call_type="",
    )

    assert (
        not should_rate_limit_called
    ), "should_rate_limit should not be called when model-specific limits don't match requested model"

    # Test 2: Model-specific rate limits configured for requested model - SHOULD be called
    should_rate_limit_called = False
    user_api_key_dict_with_matching_model_limits = UserAPIKeyAuth(
        api_key=_api_key,
        metadata={
            "model_tpm_limit": {"gpt-3.5-turbo": 1000}
        },  # Rate limit for requested model
    )

    await parallel_request_handler.async_pre_call_hook(
        user_api_key_dict=user_api_key_dict_with_matching_model_limits,
        cache=local_cache,
        data={"model": "gpt-3.5-turbo"},  # Requesting same model
        call_type="",
    )

    assert (
        should_rate_limit_called
    ), "should_rate_limit should be called when model-specific limits match requested model"

@pytest.mark.asyncio
async def test_team_member_rate_limits_v3():
    """
    Test that team member RPM/TPM rate limits are properly applied for team member combinations.
    """
    _api_key = "sk-12345"
    _api_key = hash_token(_api_key)
    _team_id = "team_123"
    _user_id = "user_456"

    user_api_key_dict = UserAPIKeyAuth(
        api_key=_api_key,
        team_id=_team_id,
        user_id=_user_id,
        team_member_rpm_limit=10,
        team_member_tpm_limit=1000,
    )

    local_cache = DualCache()
    parallel_request_handler = _PROXY_MaxParallelRequestsHandler(
        internal_usage_cache=InternalUsageCache(local_cache)
    )

    # Mock should_rate_limit to capture the descriptors
    captured_descriptors = None
    original_should_rate_limit = parallel_request_handler.should_rate_limit

    async def mock_should_rate_limit(descriptors, **kwargs):
        nonlocal captured_descriptors
        captured_descriptors = descriptors
        # Return OK response to avoid HTTPException
        return {
            "overall_code": "OK",
            "statuses": []
        }

    parallel_request_handler.should_rate_limit = mock_should_rate_limit

    # Test the pre-call hook
    await parallel_request_handler.async_pre_call_hook(
        user_api_key_dict=user_api_key_dict,
        cache=local_cache,
        data={"model": "gpt-3.5-turbo"},
        call_type="",
    )

    # Verify team member descriptor was created
    assert captured_descriptors is not None, "Rate limit descriptors should be captured"

    team_member_descriptor = None
    for descriptor in captured_descriptors:
        if descriptor["key"] == "team_member":
            team_member_descriptor = descriptor
            break

    assert team_member_descriptor is not None, "Team member descriptor should be present"
    assert team_member_descriptor["value"] == f"{_team_id}:{_user_id}", "Team member value should combine team_id and user_id"
    assert team_member_descriptor["rate_limit"]["requests_per_unit"] == 10, "Team member RPM limit should be set"
    assert team_member_descriptor["rate_limit"]["tokens_per_unit"] == 1000, "Team member TPM limit should be set"
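# Illustrative note (not from the LiteLLM implementation): based on the assertions above,
# a team-member rate-limit descriptor is a dict keyed by "team_member" whose value combines
# the team and user ids. A hypothetical builder for that descriptor shape:
def _example_team_member_descriptor(
    team_id: str, user_id: str, rpm_limit: int, tpm_limit: int
) -> Dict[str, Any]:
    """Build a descriptor in the shape the assertions above check for."""
    return {
        "key": "team_member",
        "value": f"{team_id}:{user_id}",  # e.g. "team_123:user_456"
        "rate_limit": {
            "requests_per_unit": rpm_limit,
            "tokens_per_unit": tpm_limit,
        },
    }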
@@ -0,0 +1,128 @@
import json
import os
import sys

import pytest
from fastapi.testclient import TestClient

sys.path.insert(
    0, os.path.abspath("../../../..")
)  # Adds the parent directory to the system path

from datetime import datetime
from unittest.mock import AsyncMock, MagicMock, patch

from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.hooks.proxy_track_cost_callback import _ProxyDBLogger
from litellm.types.utils import StandardLoggingPayload

@pytest.mark.asyncio
async def test_async_post_call_failure_hook():
    # Setup
    logger = _ProxyDBLogger()

    # Mock user_api_key_dict
    user_api_key_dict = UserAPIKeyAuth(
        api_key="test_api_key",
        key_alias="test_alias",
        user_email="test@example.com",
        user_id="test_user_id",
        team_id="test_team_id",
        org_id="test_org_id",
        team_alias="test_team_alias",
        end_user_id="test_end_user_id",
    )

    # Mock request data
    request_data = {
        "model": "gpt-4",
        "messages": [{"role": "user", "content": "Hello"}],
        "metadata": {"original_key": "original_value"},
        "proxy_server_request": {"request_id": "test_request_id"},
    }

    # Mock exception
    original_exception = Exception("Test exception")

    # Mock update_database function
    with patch(
        "litellm.proxy.db.db_spend_update_writer.DBSpendUpdateWriter.update_database",
        new_callable=AsyncMock,
    ) as mock_update_database:
        # Call the method
        await logger.async_post_call_failure_hook(
            request_data=request_data,
            original_exception=original_exception,
            user_api_key_dict=user_api_key_dict,
        )

        # Assertions
        mock_update_database.assert_called_once()

        # Check the arguments passed to update_database
        call_args = mock_update_database.call_args[1]
        print("call_args", json.dumps(call_args, indent=4, default=str))
        assert call_args["token"] == "test_api_key"
        assert call_args["response_cost"] == 0.0
        assert call_args["user_id"] == "test_user_id"
        assert call_args["end_user_id"] == "test_end_user_id"
        assert call_args["team_id"] == "test_team_id"
        assert call_args["org_id"] == "test_org_id"
        assert call_args["completion_response"] == original_exception

        # Check that metadata was properly updated
        assert "litellm_params" in call_args["kwargs"]
        assert call_args["kwargs"]["litellm_params"]["proxy_server_request"] == {
            "request_id": "test_request_id"
        }
        metadata = call_args["kwargs"]["litellm_params"]["metadata"]
        assert metadata["user_api_key"] == "test_api_key"
        assert metadata["status"] == "failure"
        assert "error_information" in metadata
        assert metadata["original_key"] == "original_value"

@pytest.mark.asyncio
async def test_async_post_call_failure_hook_non_llm_route():
    # Setup
    logger = _ProxyDBLogger()

    # Mock user_api_key_dict with a non-LLM route
    user_api_key_dict = UserAPIKeyAuth(
        api_key="test_api_key",
        key_alias="test_alias",
        user_email="test@example.com",
        user_id="test_user_id",
        team_id="test_team_id",
        org_id="test_org_id",
        team_alias="test_team_alias",
        end_user_id="test_end_user_id",
        request_route="/custom/route",  # Non-LLM route
    )

    # Mock request data
    request_data = {
        "model": "gpt-4",
        "messages": [{"role": "user", "content": "Hello"}],
        "metadata": {"original_key": "original_value"},
        "proxy_server_request": {"request_id": "test_request_id"},
    }

    # Mock exception
    original_exception = Exception("Test exception")

    # Mock update_database function
    with patch(
        "litellm.proxy.db.db_spend_update_writer.DBSpendUpdateWriter.update_database",
        new_callable=AsyncMock,
    ) as mock_update_database:
        # Call the method
        await logger.async_post_call_failure_hook(
            request_data=request_data,
            original_exception=original_exception,
            user_api_key_dict=user_api_key_dict,
        )

        # Assert that update_database was NOT called for non-LLM routes
        mock_update_database.assert_not_called()