Added LiteLLM to the stack
This commit is contained in:
@@ -0,0 +1,374 @@
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../../..")
|
||||
) # Adds the parent directory to the system path
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import pytest
|
||||
|
||||
import litellm
|
||||
from litellm.proxy._types import (
|
||||
LiteLLM_ObjectPermissionTable,
|
||||
LiteLLM_TeamTable,
|
||||
LiteLLM_UserTable,
|
||||
LitellmUserRoles,
|
||||
ProxyErrorTypes,
|
||||
ProxyException,
|
||||
SSOUserDefinedValues,
|
||||
UserAPIKeyAuth,
|
||||
)
|
||||
from litellm.proxy.auth.auth_checks import (
|
||||
ExperimentalUIJWTToken,
|
||||
_can_object_call_vector_stores,
|
||||
get_user_object,
|
||||
vector_store_access_check,
|
||||
)
|
||||
from litellm.proxy.common_utils.encrypt_decrypt_utils import decrypt_value_helper
|
||||
from litellm.utils import get_utc_datetime
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def set_salt_key(monkeypatch):
    """Ensure every test in this module runs with a LITELLM_SALT_KEY set.

    autouse=True means no test needs to request this fixture explicitly.
    """
    monkeypatch.setenv("LITELLM_SALT_KEY", "sk-1234")
|
||||
|
||||
|
||||
@pytest.fixture
def valid_sso_user_defined_values():
    """A fully populated user record, including the role the UI login requires."""
    user_kwargs = {
        "user_id": "test_user",
        "user_email": "test@example.com",
        "user_role": LitellmUserRoles.PROXY_ADMIN.value,
        "models": ["gpt-3.5-turbo"],
        "max_budget": 100.0,
    }
    return LiteLLM_UserTable(**user_kwargs)
|
||||
|
||||
|
||||
@pytest.fixture
def invalid_sso_user_defined_values():
    """Same user as the valid fixture, but with user_role deliberately unset."""
    return LiteLLM_UserTable(
        user_id="test_user",
        user_email="test@example.com",
        # The missing role is what makes this fixture "invalid".
        user_role=None,
        models=["gpt-3.5-turbo"],
        max_budget=100.0,
    )
|
||||
|
||||
|
||||
def test_get_experimental_ui_login_jwt_auth_token_valid(valid_sso_user_defined_values):
    """A user with a valid role yields a decryptable token carrying its fields."""
    token = ExperimentalUIJWTToken.get_experimental_ui_login_jwt_auth_token(
        valid_sso_user_defined_values
    )

    # Round-trip: the token must decrypt back to a JSON payload.
    plaintext = decrypt_value_helper(
        token, key="ui_hash_key", exception_type="debug"
    )
    assert plaintext is not None
    payload = json.loads(plaintext)

    # User identity / permissions are carried through verbatim, except the
    # budget which is capped at the UI-session budget.
    assert payload["user_id"] == "test_user"
    assert payload["user_role"] == LitellmUserRoles.PROXY_ADMIN.value
    assert payload["models"] == ["gpt-3.5-turbo"]
    assert payload["max_budget"] == litellm.max_ui_session_budget

    # Expiry must be present, in the future, and at most 10 minutes out.
    assert "expires" in payload
    expiry = datetime.fromisoformat(payload["expires"].replace("Z", "+00:00"))
    assert expiry > get_utc_datetime()
    assert expiry <= get_utc_datetime() + timedelta(minutes=10)
|
||||
|
||||
|
||||
def test_get_experimental_ui_login_jwt_auth_token_invalid(
    invalid_sso_user_defined_values,
):
    """A user without a role must be rejected at token-generation time."""
    with pytest.raises(Exception) as raised:
        ExperimentalUIJWTToken.get_experimental_ui_login_jwt_auth_token(
            invalid_sso_user_defined_values
        )

    # The exact message is part of the contract surfaced to callers.
    assert str(raised.value) == "User role is required for experimental UI login"
|
||||
|
||||
|
||||
def test_get_key_object_from_ui_hash_key_valid(
    valid_sso_user_defined_values, monkeypatch
):
    """A freshly minted UI token resolves back to a key object with the same fields."""
    monkeypatch.setenv("EXPERIMENTAL_UI_LOGIN", "True")

    ui_token = ExperimentalUIJWTToken.get_experimental_ui_login_jwt_auth_token(
        valid_sso_user_defined_values
    )
    resolved = ExperimentalUIJWTToken.get_key_object_from_ui_hash_key(ui_token)

    assert resolved is not None
    assert resolved.user_id == "test_user"
    assert resolved.user_role == LitellmUserRoles.PROXY_ADMIN
    assert resolved.models == ["gpt-3.5-turbo"]
    # Budget on the key object is the UI-session budget, not the user's own.
    assert resolved.max_budget == litellm.max_ui_session_budget
|
||||
|
||||
|
||||
def test_get_key_object_from_ui_hash_key_invalid():
    """An undecryptable token must map to no key object at all."""
    assert (
        ExperimentalUIJWTToken.get_key_object_from_ui_hash_key("invalid_token") is None
    )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_default_internal_user_params_with_get_user_object(monkeypatch):
    """Test that default_internal_user_params is used when creating a new user via get_user_object"""
    # Defaults the proxy should apply to any auto-created user.
    default_params = {
        "models": ["gpt-4", "claude-3-opus"],
        "max_budget": 200.0,
        "user_role": "internal_user",
    }
    monkeypatch.setattr(litellm, "default_internal_user_params", default_params)

    # Fake prisma client whose db layer we can inspect afterwards.
    mock_prisma_client = MagicMock()
    mock_prisma_client.db = AsyncMock()

    # The "row" returned by the mocked create() call.
    created_user = MagicMock()
    created_user.user_id = "new_test_user"
    created_user.models = ["gpt-4", "claude-3-opus"]
    created_user.max_budget = 200.0
    created_user.user_role = "internal_user"
    created_user.organization_memberships = []
    created_user.dict = lambda: {
        "user_id": "new_test_user",
        "models": ["gpt-4", "claude-3-opus"],
        "max_budget": 200.0,
        "user_role": "internal_user",
        "organization_memberships": [],
    }

    # No existing user -> get_user_object must fall through to create().
    mock_prisma_client.db.litellm_usertable.find_unique = AsyncMock(return_value=None)
    mock_prisma_client.db.litellm_usertable.create = AsyncMock(return_value=created_user)

    # Cache misses everywhere so the DB path is exercised.
    mock_cache = MagicMock()
    mock_cache.async_get_cache = AsyncMock(return_value=None)
    mock_cache.async_set_cache = AsyncMock()

    # user_id_upsert=True forces user creation for the unknown user id.
    try:
        await get_user_object(
            user_id="new_test_user",
            prisma_client=mock_prisma_client,
            user_api_key_cache=mock_cache,
            user_id_upsert=True,
            proxy_logging_obj=None,
        )
    except Exception as e:
        # this fails since the mock object is a MagicMock and not a LiteLLM_UserTable
        print(e)

    # The interesting part: what data was handed to create().
    mock_prisma_client.db.litellm_usertable.create.assert_called_once()
    creation_args = mock_prisma_client.db.litellm_usertable.create.call_args[1]["data"]

    assert "models" in creation_args
    assert creation_args["models"] == ["gpt-4", "claude-3-opus"]
    assert creation_args["max_budget"] == 200.0
    assert creation_args["user_role"] == "internal_user"
|
||||
|
||||
|
||||
# Vector Store Auth Check Tests
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "prisma_client,vector_store_registry,expected_result",
    [
        (None, MagicMock(), True),  # No prisma client
        (MagicMock(), None, True),  # No vector store registry
        (MagicMock(), MagicMock(), True),  # No vector stores to run
    ],
)
async def test_vector_store_access_check_early_returns(
    prisma_client, vector_store_registry, expected_result
):
    """Test vector_store_access_check returns True for early exit conditions"""
    request_body = {"messages": [{"role": "user", "content": "test"}]}

    # When a registry exists, make it report "nothing to run".
    if vector_store_registry:
        vector_store_registry.get_vector_store_ids_to_run.return_value = None

    patched_client = patch("litellm.proxy.proxy_server.prisma_client", prisma_client)
    patched_registry = patch("litellm.vector_store_registry", vector_store_registry)
    with patched_client, patched_registry:
        outcome = await vector_store_access_check(
            request_body=request_body,
            team_object=None,
            valid_token=None,
        )

    assert outcome == expected_result
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "object_permissions,vector_store_ids,should_raise,error_type",
    [
        (None, ["store-1"], False, None),  # None permissions - should pass
        (
            {"vector_stores": []},
            ["store-1"],
            False,
            None,
        ),  # Empty vector_stores - should pass (access to all)
        (
            {"vector_stores": ["store-1", "store-2"]},
            ["store-1"],
            False,
            None,
        ),  # Has access
        (
            {"vector_stores": ["store-1", "store-2"]},
            ["store-3"],
            True,
            ProxyErrorTypes.key_vector_store_access_denied,
        ),  # No access
        (
            {"vector_stores": ["store-1"]},
            ["store-1", "store-3"],
            True,
            ProxyErrorTypes.team_vector_store_access_denied,
        ),  # Partial access
    ],
)
def test_can_object_call_vector_stores_scenarios(
    object_permissions, vector_store_ids, should_raise, error_type
):
    """Test _can_object_call_vector_stores with various permission scenarios"""
    # Wrap the plain dict in an object exposing a .vector_stores attribute,
    # mirroring the LiteLLM_ObjectPermissionTable shape the checker expects.
    if object_permissions is not None:
        perms = MagicMock()
        perms.vector_stores = object_permissions["vector_stores"]
        object_permissions = perms

    # The error type encodes which kind of object is being checked.
    if error_type == ProxyErrorTypes.key_vector_store_access_denied:
        object_type = "key"
    else:
        object_type = "team"

    if not should_raise:
        assert (
            _can_object_call_vector_stores(
                object_type=object_type,
                vector_store_ids_to_run=vector_store_ids,
                object_permissions=object_permissions,
            )
            is True
        )
        return

    with pytest.raises(ProxyException) as exc_info:
        _can_object_call_vector_stores(
            object_type=object_type,
            vector_store_ids_to_run=vector_store_ids,
            object_permissions=object_permissions,
        )
    assert exc_info.value.type == error_type
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_vector_store_access_check_with_permissions():
    """Test vector_store_access_check with actual permission checking"""
    request_body = {"tools": [{"type": "function", "function": {"name": "test"}}]}

    valid_token = UserAPIKeyAuth(
        token="test-token",
        object_permission_id="perm-123",
        models=["gpt-4"],
        max_budget=100.0,
    )

    # DB lookup returns a permission row granting store-1 and store-2.
    permission_row = MagicMock()
    permission_row.vector_stores = ["store-1", "store-2"]
    mock_prisma_client = MagicMock()
    mock_prisma_client.db.litellm_objectpermissiontable.find_unique = AsyncMock(
        return_value=permission_row
    )

    mock_vector_store_registry = MagicMock()

    # Case 1: requested store is covered by the permissions -> allowed.
    mock_vector_store_registry.get_vector_store_ids_to_run.return_value = ["store-1"]
    with patch("litellm.proxy.proxy_server.prisma_client", mock_prisma_client), patch(
        "litellm.vector_store_registry", mock_vector_store_registry
    ):
        result = await vector_store_access_check(
            request_body=request_body,
            team_object=None,
            valid_token=valid_token,
        )
    assert result is True

    # Case 2: requested store is outside the permissions -> denied for the key.
    mock_vector_store_registry.get_vector_store_ids_to_run.return_value = ["store-3"]
    with patch("litellm.proxy.proxy_server.prisma_client", mock_prisma_client), patch(
        "litellm.vector_store_registry", mock_vector_store_registry
    ):
        with pytest.raises(ProxyException) as exc_info:
            await vector_store_access_check(
                request_body=request_body,
                team_object=None,
                valid_token=valid_token,
            )

    assert exc_info.value.type == ProxyErrorTypes.key_vector_store_access_denied
|
||||
|
||||
|
||||
def test_can_object_call_model_with_alias():
    """Test that can_object_call_model works with model aliases.

    A key whose allowed models contain only the underlying model name
    ("gpt-3.5-turbo") should still be able to call the hidden router alias
    "[ip-approved] gpt-4o" that maps onto it.
    """
    from litellm import Router
    from litellm.proxy.auth.auth_checks import _can_object_call_model

    model = "[ip-approved] gpt-4o"
    llm_router = Router(
        model_list=[
            {
                "model_name": "gpt-3.5-turbo",
                "litellm_params": {
                    "model": "gpt-3.5-turbo",
                    "api_key": "test-api-key",
                },
            }
        ],
        # Hidden alias: callers may request the alias name, which the
        # router resolves to the real deployment above.
        model_group_alias={
            "[ip-approved] gpt-4o": {
                "model": "gpt-3.5-turbo",
                "hidden": True,
            },
        },
    )

    result = _can_object_call_model(
        model=model,
        llm_router=llm_router,
        models=["gpt-3.5-turbo"],
        team_model_aliases=None,
        object_type="key",
        fallback_depth=0,
    )

    # Fix: the original test only printed the result, so it could never
    # fail. Assert that the access check actually allows the aliased model.
    assert result is True
|
@@ -0,0 +1,154 @@
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException, Request, status
|
||||
from prisma import errors as prisma_errors
|
||||
from prisma.errors import (
|
||||
ClientNotConnectedError,
|
||||
DataError,
|
||||
ForeignKeyViolationError,
|
||||
HTTPClientClosedError,
|
||||
MissingRequiredValueError,
|
||||
PrismaError,
|
||||
RawQueryError,
|
||||
RecordNotFoundError,
|
||||
TableNotFoundError,
|
||||
UniqueViolationError,
|
||||
)
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../../..")
|
||||
) # Adds the parent directory to the system path
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.proxy._types import ProxyErrorTypes, ProxyException
|
||||
from litellm.proxy.auth.auth_exception_handler import UserAPIKeyAuthExceptionHandler
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "prisma_error",
    [
        PrismaError(),
        DataError(data={"user_facing_error": {"meta": {"table": "test_table"}}}),
        UniqueViolationError(
            data={"user_facing_error": {"meta": {"table": "test_table"}}}
        ),
        ForeignKeyViolationError(
            data={"user_facing_error": {"meta": {"table": "test_table"}}}
        ),
        MissingRequiredValueError(
            data={"user_facing_error": {"meta": {"table": "test_table"}}}
        ),
        RawQueryError(data={"user_facing_error": {"meta": {"table": "test_table"}}}),
        TableNotFoundError(
            data={"user_facing_error": {"meta": {"table": "test_table"}}}
        ),
        RecordNotFoundError(
            data={"user_facing_error": {"meta": {"table": "test_table"}}}
        ),
        HTTPClientClosedError(),
        ClientNotConnectedError(),
    ],
)
async def test_handle_authentication_error_db_unavailable(prisma_error):
    """Every prisma error type yields a virtual key when
    allow_requests_on_db_unavailable is enabled."""
    handler = UserAPIKeyAuthExceptionHandler()

    settings = {"allow_requests_on_db_unavailable": True}
    with patch("litellm.proxy.proxy_server.general_settings", settings):
        result = await handler._handle_authentication_error(
            prisma_error,
            MagicMock(),  # request
            {},  # request_data
            "/test",  # route
            None,  # span
            "test-key",  # api_key
        )

    # The handler hands back a sentinel "virtual" key instead of failing.
    assert result.key_name == "failed-to-connect-to-db"
    assert result.token == "failed-to-connect-to-db"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_handle_authentication_error_budget_exceeded():
    """A BudgetExceededError must surface as a ProxyException of type budget_exceeded."""
    from litellm.exceptions import BudgetExceededError

    handler = UserAPIKeyAuthExceptionHandler()

    budget_error = BudgetExceededError(
        message="Budget exceeded", current_cost=100, max_budget=100
    )

    with pytest.raises(ProxyException) as exc_info:
        await handler._handle_authentication_error(
            budget_error,
            MagicMock(),  # request
            {},  # request_data
            "/test",  # route
            None,  # span
            "test-key",  # api_key
        )

    assert exc_info.value.type == ProxyErrorTypes.budget_exceeded
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_route_passed_to_post_call_failure_hook():
    """
    This route is used by proxy track_cost_callback's async_post_call_failure_hook to check if the route is an LLM route
    """
    handler = UserAPIKeyAuthExceptionHandler()

    # Mock request and other dependencies
    mock_request = MagicMock()
    mock_request_data = {}
    test_route = "/custom/route"
    mock_span = None
    mock_api_key = "test-key"

    # Mock proxy_logging_obj.post_call_failure_hook
    with patch(
        "litellm.proxy.proxy_server.proxy_logging_obj.post_call_failure_hook",
        new_callable=AsyncMock,
    ) as mock_post_call_failure_hook:
        # Test with DB connection error
        with patch(
            "litellm.proxy.proxy_server.general_settings",
            {"allow_requests_on_db_unavailable": False},
        ):
            try:
                await handler._handle_authentication_error(
                    PrismaError(),
                    mock_request,
                    mock_request_data,
                    test_route,
                    mock_span,
                    mock_api_key,
                )
            except Exception:
                # DB is unavailable and requests are not allowed, so the
                # handler is expected to raise; we only care about the hook.
                pass
        # Fix: the original called asyncio.sleep(1) without awaiting it,
        # which never yields to the event loop (and emits a "coroutine was
        # never awaited" RuntimeWarning). Await a zero-length sleep so any
        # scheduled hook tasks get a chance to run.
        await asyncio.sleep(0)
        # Verify post_call_failure_hook was called with the correct route
        mock_post_call_failure_hook.assert_called_once()
        call_args = mock_post_call_failure_hook.call_args[1]
        assert call_args["user_api_key_dict"].request_route == test_route
|
@@ -0,0 +1,849 @@
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from litellm.proxy._types import (
|
||||
JWTLiteLLMRoleMap,
|
||||
LiteLLM_JWTAuth,
|
||||
LiteLLM_TeamTable,
|
||||
LiteLLM_UserTable,
|
||||
LitellmUserRoles,
|
||||
Member,
|
||||
ProxyErrorTypes,
|
||||
ProxyException,
|
||||
)
|
||||
from litellm.proxy.auth.handle_jwt import JWTAuthManager, JWTHandler
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_map_user_to_teams_user_already_in_team():
    """Test that no action is taken when user is already in team"""
    existing_member = Member(user_id="test_user_1", role="user")
    user = LiteLLM_UserTable(user_id="test_user_1")
    team = LiteLLM_TeamTable(
        team_id="test_team_1",
        members_with_roles=[existing_member],
    )

    # The user is already a member, so team_member_add must never fire.
    with patch(
        "litellm.proxy.management_endpoints.team_endpoints.team_member_add",
        new_callable=AsyncMock,
    ) as mock_add:
        await JWTAuthManager.map_user_to_teams(user_object=user, team_object=team)
        mock_add.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_map_user_to_teams_add_new_user():
    """Test that new user is added to team"""
    user = LiteLLM_UserTable(user_id="test_user_1")
    team = LiteLLM_TeamTable(team_id="test_team_1", members_with_roles=[])

    with patch(
        "litellm.proxy.management_endpoints.team_endpoints.team_member_add",
        new_callable=AsyncMock,
    ) as mock_add:
        await JWTAuthManager.map_user_to_teams(user_object=user, team_object=team)
        mock_add.assert_called_once()

        # Inspect the payload handed to team_member_add: the user must be
        # added to the right team with the default "user" role.
        payload = mock_add.call_args[1]["data"]
        assert payload.member.user_id == "test_user_1"
        assert payload.member.role == "user"
        assert payload.team_id == "test_team_1"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_map_user_to_teams_handles_already_in_team_exception():
    """Test that team_member_already_in_team exception is handled gracefully"""
    user = LiteLLM_UserTable(user_id="test_user_1")
    team = LiteLLM_TeamTable(team_id="test_team_1", members_with_roles=[])

    # The specific error type that map_user_to_teams is expected to swallow.
    duplicate_member_error = ProxyException(
        message="User test_user_1 already in team",
        type=ProxyErrorTypes.team_member_already_in_team,
        param="user_id",
        code="400",
    )

    with patch(
        "litellm.proxy.management_endpoints.team_endpoints.team_member_add",
        new_callable=AsyncMock,
        side_effect=duplicate_member_error,
    ) as mock_add:
        with patch("litellm.proxy.auth.handle_jwt.verbose_proxy_logger") as mock_logger:
            # Must complete without propagating the "already in team" error.
            result = await JWTAuthManager.map_user_to_teams(
                user_object=user, team_object=team
            )

            assert result is None
            mock_add.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_map_user_to_teams_reraises_other_proxy_exceptions():
    """Test that other ProxyException types are re-raised"""
    # Setup test data
    user = LiteLLM_UserTable(user_id="test_user_1")
    team = LiteLLM_TeamTable(team_id="test_team_1", members_with_roles=[])

    # A ProxyException that is NOT team_member_already_in_team, so it must
    # propagate out of map_user_to_teams unchanged.
    other_exception = ProxyException(
        message="Some other error",
        type=ProxyErrorTypes.internal_server_error,
        param="some_param",
        code="500",
    )

    # Mock team_member_add to raise the exception
    with patch(
        "litellm.proxy.management_endpoints.team_endpoints.team_member_add",
        new_callable=AsyncMock,
        side_effect=other_exception,
    ) as mock_add:
        with pytest.raises(ProxyException) as exc_info:
            await JWTAuthManager.map_user_to_teams(user_object=user, team_object=team)

    # Fix: the original captured exc_info but never inspected it, so a wrong
    # exception type would still pass. Pin down exactly what was re-raised.
    assert exc_info.value.type == ProxyErrorTypes.internal_server_error
    mock_add.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_map_user_to_teams_null_inputs():
    """Test that method handles null inputs gracefully"""
    a_team = LiteLLM_TeamTable(team_id="test_team_1")
    a_user = LiteLLM_UserTable(user_id="test_user_1")

    # None user, None team, and both None must all be no-ops (no exception).
    await JWTAuthManager.map_user_to_teams(user_object=None, team_object=a_team)
    await JWTAuthManager.map_user_to_teams(user_object=a_user, team_object=None)
    await JWTAuthManager.map_user_to_teams(user_object=None, team_object=None)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_auth_builder_proxy_admin_user_role():
    """Test that is_proxy_admin is True when user_object.user_role is PROXY_ADMIN"""
    api_key = "test_jwt_token"
    request_data = {"model": "gpt-4"}
    general_settings = {"enforce_rbac": False}
    route = "/chat/completions"

    # User record carrying the PROXY_ADMIN role.
    admin_user = LiteLLM_UserTable(
        user_id="test_user_1", user_role=LitellmUserRoles.PROXY_ADMIN
    )

    jwt_handler = JWTHandler()
    jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth()

    # Stub out every collaborator so auth_builder only exercises its own
    # role-derivation logic.
    with patch.object(
        jwt_handler, "auth_jwt", new_callable=AsyncMock
    ) as patched_auth_jwt, patch.object(
        JWTAuthManager, "check_rbac_role", new_callable=AsyncMock
    ), patch.object(
        jwt_handler, "get_rbac_role", return_value=None
    ), patch.object(
        jwt_handler, "get_scopes", return_value=[]
    ), patch.object(
        jwt_handler, "get_object_id", return_value=None
    ), patch.object(
        JWTAuthManager,
        "get_user_info",
        new_callable=AsyncMock,
        return_value=("test_user_1", "test@example.com", True),
    ), patch.object(
        jwt_handler, "get_org_id", return_value=None
    ), patch.object(
        jwt_handler, "get_end_user_id", return_value=None
    ), patch.object(
        JWTAuthManager, "check_admin_access", new_callable=AsyncMock, return_value=None
    ), patch.object(
        JWTAuthManager,
        "find_and_validate_specific_team_id",
        new_callable=AsyncMock,
        return_value=(None, None),
    ), patch.object(
        JWTAuthManager, "get_all_team_ids", return_value=set()
    ), patch.object(
        JWTAuthManager,
        "find_team_with_model_access",
        new_callable=AsyncMock,
        return_value=(None, None),
    ), patch.object(
        JWTAuthManager,
        "get_objects",
        new_callable=AsyncMock,
        return_value=(admin_user, None, None, None),
    ), patch.object(
        JWTAuthManager, "map_user_to_teams", new_callable=AsyncMock
    ), patch.object(
        JWTAuthManager, "validate_object_id", return_value=True
    ):
        patched_auth_jwt.return_value = {"sub": "test_user_1", "scope": ""}

        result = await JWTAuthManager.auth_builder(
            api_key=api_key,
            jwt_handler=jwt_handler,
            request_data=request_data,
            general_settings=general_settings,
            route=route,
            prisma_client=None,
            user_api_key_cache=None,
            parent_otel_span=None,
            proxy_logging_obj=None,
        )

    # PROXY_ADMIN role on the user object must mark the request as admin.
    assert result["is_proxy_admin"] is True
    assert result["user_object"] == admin_user
    assert result["user_id"] == "test_user_1"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_auth_builder_non_proxy_admin_user_role():
    """Test that is_proxy_admin is False when user_object.user_role is not PROXY_ADMIN"""
    api_key = "test_jwt_token"
    request_data = {"model": "gpt-4"}
    general_settings = {"enforce_rbac": False}
    route = "/chat/completions"

    # User record carrying a non-admin (INTERNAL_USER) role.
    regular_user = LiteLLM_UserTable(
        user_id="test_user_1", user_role=LitellmUserRoles.INTERNAL_USER
    )

    jwt_handler = JWTHandler()
    jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth()

    # Stub out every collaborator so auth_builder only exercises its own
    # role-derivation logic.
    with patch.object(
        jwt_handler, "auth_jwt", new_callable=AsyncMock
    ) as patched_auth_jwt, patch.object(
        JWTAuthManager, "check_rbac_role", new_callable=AsyncMock
    ), patch.object(
        jwt_handler, "get_rbac_role", return_value=None
    ), patch.object(
        jwt_handler, "get_scopes", return_value=[]
    ), patch.object(
        jwt_handler, "get_object_id", return_value=None
    ), patch.object(
        JWTAuthManager,
        "get_user_info",
        new_callable=AsyncMock,
        return_value=("test_user_1", "test@example.com", True),
    ), patch.object(
        jwt_handler, "get_org_id", return_value=None
    ), patch.object(
        jwt_handler, "get_end_user_id", return_value=None
    ), patch.object(
        JWTAuthManager, "check_admin_access", new_callable=AsyncMock, return_value=None
    ), patch.object(
        JWTAuthManager,
        "find_and_validate_specific_team_id",
        new_callable=AsyncMock,
        return_value=(None, None),
    ), patch.object(
        JWTAuthManager, "get_all_team_ids", return_value=set()
    ), patch.object(
        JWTAuthManager,
        "find_team_with_model_access",
        new_callable=AsyncMock,
        return_value=(None, None),
    ), patch.object(
        JWTAuthManager,
        "get_objects",
        new_callable=AsyncMock,
        return_value=(regular_user, None, None, None),
    ), patch.object(
        JWTAuthManager, "map_user_to_teams", new_callable=AsyncMock
    ), patch.object(
        JWTAuthManager, "validate_object_id", return_value=True
    ):
        patched_auth_jwt.return_value = {"sub": "test_user_1", "scope": ""}

        result = await JWTAuthManager.auth_builder(
            api_key=api_key,
            jwt_handler=jwt_handler,
            request_data=request_data,
            general_settings=general_settings,
            route=route,
            prisma_client=None,
            user_api_key_cache=None,
            parent_otel_span=None,
            proxy_logging_obj=None,
        )

    # A non-admin role must NOT mark the request as admin.
    assert result["is_proxy_admin"] is False
    assert result["user_object"] == regular_user
    assert result["user_id"] == "test_user_1"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_sync_user_role_and_teams():
    """JWT role/team claims should be synced onto an out-of-date user record."""
    from unittest.mock import MagicMock

    mock_user_api_key_cache = MagicMock()
    mock_proxy_logging_obj = MagicMock()

    jwt_handler = JWTHandler()
    jwt_handler.update_environment(
        prisma_client=None,
        user_api_key_cache=mock_user_api_key_cache,
        litellm_jwtauth=LiteLLM_JWTAuth(
            jwt_litellm_role_map=[
                JWTLiteLLMRoleMap(
                    jwt_role="ADMIN", litellm_role=LitellmUserRoles.PROXY_ADMIN
                )
            ],
            roles_jwt_field="roles",
            team_ids_jwt_field="my_id_teams",
            sync_user_role_and_teams=True,
        ),
    )

    # Token claims: ADMIN role, membership in team1 + team2.
    token = {"roles": ["ADMIN"], "my_id_teams": ["team1", "team2"]}

    # Stored user is stale: internal role, only team2.
    user = LiteLLM_UserTable(
        user_id="u1",
        user_role=LitellmUserRoles.INTERNAL_USER.value,
        teams=["team2"],
    )

    prisma = AsyncMock()
    prisma.db.litellm_usertable.update = AsyncMock()

    with patch(
        "litellm.proxy.management_endpoints.scim.scim_v2.patch_team_membership",
        new_callable=AsyncMock,
    ) as mock_patch:
        await JWTAuthManager.sync_user_role_and_teams(jwt_handler, token, user, prisma)

    # Role upgraded, team1 added; DB and SCIM each notified exactly once.
    prisma.db.litellm_usertable.update.assert_called_once()
    mock_patch.assert_called_once()
    assert user.user_role == LitellmUserRoles.PROXY_ADMIN.value
    assert set(user.teams) == {"team1", "team2"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_map_jwt_role_to_litellm_role():
    """Test JWT role mapping to LiteLLM roles with various patterns"""
    from unittest.mock import MagicMock

    jwt_handler = JWTHandler()
    jwt_handler.update_environment(
        prisma_client=None,
        user_api_key_cache=MagicMock(),
        litellm_jwtauth=LiteLLM_JWTAuth(
            jwt_litellm_role_map=[
                # Exact match
                JWTLiteLLMRoleMap(jwt_role="ADMIN", litellm_role=LitellmUserRoles.PROXY_ADMIN),
                # Wildcard patterns
                JWTLiteLLMRoleMap(jwt_role="user_*", litellm_role=LitellmUserRoles.INTERNAL_USER),
                JWTLiteLLMRoleMap(jwt_role="team_?", litellm_role=LitellmUserRoles.TEAM),
                JWTLiteLLMRoleMap(jwt_role="dev_[123]", litellm_role=LitellmUserRoles.INTERNAL_USER),
            ],
            roles_jwt_field="roles"
        ),
    )

    # (token, expected litellm role) pairs for the mapping table above
    cases = [
        ({"roles": ["ADMIN"]}, LitellmUserRoles.PROXY_ADMIN),           # exact match
        ({"roles": ["user_manager"]}, LitellmUserRoles.INTERNAL_USER),  # '*' wildcard
        ({"roles": ["user_"]}, LitellmUserRoles.INTERNAL_USER),         # '*' matches empty suffix
        ({"roles": ["team_1"]}, LitellmUserRoles.TEAM),                 # '?' single char
        ({"roles": ["team_a"]}, LitellmUserRoles.TEAM),
        ({"roles": ["dev_1"]}, LitellmUserRoles.INTERNAL_USER),         # character class
        ({"roles": ["dev_2"]}, LitellmUserRoles.INTERNAL_USER),
        ({"roles": ["unknown_role"]}, None),                            # no mapping matches
        ({"roles": ["user_test", "ADMIN"]}, LitellmUserRoles.PROXY_ADMIN),  # first mapping wins
        ({"roles": []}, None),                                          # empty roles list
        ({}, None),                                                     # roles field absent
    ]
    for token, expected in cases:
        assert jwt_handler.map_jwt_role_to_litellm_role(token) == expected

    # With no mappings configured (None or empty list), nothing can match
    for empty_map in (None, []):
        jwt_handler.litellm_jwtauth.jwt_litellm_role_map = empty_map
        assert jwt_handler.map_jwt_role_to_litellm_role({"roles": ["ADMIN"]}) is None

    # Character class that excludes the candidate role: 4 is not in [123]
    jwt_handler.litellm_jwtauth.jwt_litellm_role_map = [
        JWTLiteLLMRoleMap(jwt_role="dev_[123]", litellm_role=LitellmUserRoles.INTERNAL_USER),
    ]
    assert jwt_handler.map_jwt_role_to_litellm_role({"roles": ["dev_4"]}) is None

    # '?' requires exactly one character after the underscore
    jwt_handler.litellm_jwtauth.jwt_litellm_role_map = [
        JWTLiteLLMRoleMap(jwt_role="team_?", litellm_role=LitellmUserRoles.TEAM),
    ]
    assert jwt_handler.map_jwt_role_to_litellm_role({"roles": ["team_12"]}) is None  # too long
    assert jwt_handler.map_jwt_role_to_litellm_role({"roles": ["team_"]}) is None    # too short
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_nested_jwt_field_access():
    """
    Test that all JWT fields support dot notation for nested access

    This test verifies that:
    1. All JWT field methods can access nested values using dot notation
    2. Backward compatibility is maintained for flat field names
    3. Missing nested paths return appropriate defaults
    """
    from litellm.proxy._types import LiteLLM_JWTAuth
    from litellm.proxy.auth.handle_jwt import JWTHandler

    # Create JWT handler
    jwt_handler = JWTHandler()

    # Test token with nested claims (Keycloak-style resource_access layout)
    nested_token = {
        "user": {
            "sub": "u123",
            "email": "user@example.com"
        },
        "resource_access": {
            "my-client": {
                "roles": ["admin", "user"]
            }
        },
        "groups": ["team1", "team2"],
        "organization": {
            "id": "org456"
        },
        "profile": {
            "object_id": "obj789"
        },
        "customer": {
            "end_user_id": "customer123"
        },
        "tenant": {
            "team_id": "team456"
        }
    }

    # Test flat token for backward compatibility (no dot-path traversal needed)
    flat_token = {
        "sub": "u123",
        "email": "user@example.com",
        "roles": ["admin", "user"],
        "groups": ["team1", "team2"],
        "org_id": "org456",
        "object_id": "obj789",
        "end_user_id": "customer123",
        "team_id": "team456"
    }

    # Test 1: user_id_jwt_field with nested access
    jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(user_id_jwt_field="user.sub")
    assert jwt_handler.get_user_id(nested_token, None) == "u123"

    # Test 1b: user_id_jwt_field with flat access (backward compatibility)
    jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(user_id_jwt_field="sub")
    assert jwt_handler.get_user_id(flat_token, None) == "u123"

    # Test 2: user_email_jwt_field with nested access
    jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(user_email_jwt_field="user.email")
    assert jwt_handler.get_user_email(nested_token, None) == "user@example.com"

    # Test 2b: user_email_jwt_field with flat access (backward compatibility)
    jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(user_email_jwt_field="email")
    assert jwt_handler.get_user_email(flat_token, None) == "user@example.com"

    # Test 3: team_ids_jwt_field with a top-level list claim.
    # NOTE(review): "groups" is a flat key in BOTH tokens, so this pair does
    # not actually exercise nested traversal — it covers the list-claim path.
    jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(team_ids_jwt_field="groups")
    assert jwt_handler.get_team_ids_from_jwt(nested_token) == ["team1", "team2"]

    # Test 3b: team_ids_jwt_field with flat access (backward compatibility)
    jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(team_ids_jwt_field="groups")
    assert jwt_handler.get_team_ids_from_jwt(flat_token) == ["team1", "team2"]

    # Test 4: org_id_jwt_field with nested access
    jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(org_id_jwt_field="organization.id")
    assert jwt_handler.get_org_id(nested_token, None) == "org456"

    # Test 4b: org_id_jwt_field with flat access (backward compatibility)
    jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(org_id_jwt_field="org_id")
    assert jwt_handler.get_org_id(flat_token, None) == "org456"

    # Test 5: object_id_jwt_field with nested access (requires role_mappings)
    from litellm.proxy._types import LitellmUserRoles, RoleMapping
    jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(
        object_id_jwt_field="profile.object_id",
        role_mappings=[RoleMapping(role="admin", internal_role=LitellmUserRoles.INTERNAL_USER)]
    )
    assert jwt_handler.get_object_id(nested_token, None) == "obj789"

    # Test 5b: object_id_jwt_field with flat access (backward compatibility)
    jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(
        object_id_jwt_field="object_id",
        role_mappings=[RoleMapping(role="admin", internal_role=LitellmUserRoles.INTERNAL_USER)]
    )
    assert jwt_handler.get_object_id(flat_token, None) == "obj789"

    # Test 6: end_user_id_jwt_field with nested access
    jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(end_user_id_jwt_field="customer.end_user_id")
    assert jwt_handler.get_end_user_id(nested_token, None) == "customer123"

    # Test 6b: end_user_id_jwt_field with flat access (backward compatibility)
    jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(end_user_id_jwt_field="end_user_id")
    assert jwt_handler.get_end_user_id(flat_token, None) == "customer123"

    # Test 7: team_id_jwt_field with nested access
    jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(team_id_jwt_field="tenant.team_id")
    assert jwt_handler.get_team_id(nested_token, None) == "team456"

    # Test 7b: team_id_jwt_field with flat access (backward compatibility)
    jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(team_id_jwt_field="team_id")
    assert jwt_handler.get_team_id(flat_token, None) == "team456"

    # Test 8: roles_jwt_field with deeply nested access (already supported, but testing)
    jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(roles_jwt_field="resource_access.my-client.roles")
    assert jwt_handler.get_jwt_role(nested_token, []) == ["admin", "user"]

    # Test 9: user_roles_jwt_field with nested access (already supported, but testing)
    jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(
        user_roles_jwt_field="resource_access.my-client.roles",
        user_allowed_roles=["admin", "user"]
    )
    assert jwt_handler.get_user_roles(nested_token, []) == ["admin", "user"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_nested_jwt_field_missing_paths():
    """
    Test handling of missing nested paths in JWT tokens

    This test verifies that:
    1. Missing nested paths return appropriate defaults
    2. Partial paths that exist but don't have the final key return defaults
    3. team_id_default fallback works with nested fields
    """
    from litellm.proxy._types import LiteLLM_JWTAuth
    from litellm.proxy.auth.handle_jwt import JWTHandler

    # Create JWT handler
    jwt_handler = JWTHandler()

    # Test token with missing nested paths: intermediate objects exist but
    # lack the final keys, and several top-level claims are absent entirely
    incomplete_token = {
        "user": {
            "name": "test user"
            # missing "sub" and "email"
        },
        "resource_access": {
            "other-client": {
                "roles": ["viewer"]
            }
            # missing "my-client"
        }
        # missing "organization", "profile", "customer", "tenant", "groups"
    }

    # Test 1: Missing user.sub should return default
    jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(user_id_jwt_field="user.sub")
    assert jwt_handler.get_user_id(incomplete_token, "default_user") == "default_user"

    # Test 2: Missing user.email should return default
    jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(user_email_jwt_field="user.email")
    assert jwt_handler.get_user_email(incomplete_token, "default@example.com") == "default@example.com"

    # Test 3: Missing groups should return empty list
    jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(team_ids_jwt_field="groups")
    assert jwt_handler.get_team_ids_from_jwt(incomplete_token) == []

    # Test 4: Missing organization.id should return default
    jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(org_id_jwt_field="organization.id")
    assert jwt_handler.get_org_id(incomplete_token, "default_org") == "default_org"

    # Test 5: Missing profile.object_id should return default (requires role_mappings)
    from litellm.proxy._types import LitellmUserRoles, RoleMapping
    jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(
        object_id_jwt_field="profile.object_id",
        role_mappings=[RoleMapping(role="admin", internal_role=LitellmUserRoles.INTERNAL_USER)]
    )
    assert jwt_handler.get_object_id(incomplete_token, "default_obj") == "default_obj"

    # Test 6: Missing customer.end_user_id should return default
    jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(end_user_id_jwt_field="customer.end_user_id")
    assert jwt_handler.get_end_user_id(incomplete_token, "default_customer") == "default_customer"

    # Test 7: Missing tenant.team_id should use team_id_default fallback —
    # note it wins over the caller-supplied "default_team" argument
    jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(
        team_id_jwt_field="tenant.team_id",
        team_id_default="fallback_team"
    )
    assert jwt_handler.get_team_id(incomplete_token, "default_team") == "fallback_team"

    # Test 8: Missing resource_access.my-client.roles should return default
    jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(roles_jwt_field="resource_access.my-client.roles")
    assert jwt_handler.get_jwt_role(incomplete_token, ["default_role"]) == ["default_role"]

    # Test 9: Missing nested user roles should return default
    jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(
        user_roles_jwt_field="resource_access.my-client.roles",
        user_allowed_roles=["admin", "user"]
    )
    assert jwt_handler.get_user_roles(incomplete_token, ["default_user_role"]) == ["default_user_role"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_metadata_prefix_handling_in_nested_fields():
    """
    Test that metadata. prefix is properly handled in nested JWT field access

    The get_nested_value function should remove metadata. prefix before traversing
    """
    from litellm.proxy._types import LiteLLM_JWTAuth
    from litellm.proxy.auth.handle_jwt import JWTHandler

    handler = JWTHandler()

    # Claims laid out so that stripping "metadata." exposes user.email
    claims = {
        "user": {
            "email": "user@example.com"  # This will be accessed when metadata.user.email is used
        },
        "sub": "u123"
    }

    # "metadata.user.email" is reduced to "user.email" before traversal
    handler.litellm_jwtauth = LiteLLM_JWTAuth(user_email_jwt_field="metadata.user.email")
    assert handler.get_user_email(claims, None) == "user@example.com"

    # A plain flat field still resolves without any prefix handling
    handler.litellm_jwtauth = LiteLLM_JWTAuth(user_id_jwt_field="sub")
    assert handler.get_user_id(claims, None) == "u123"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_find_team_with_model_access_model_group(monkeypatch):
    """
    Verify find_team_with_model_access resolves a team whose model list
    grants access through an access group ("test-group") rather than the
    literal model name, using the router's model_info to link the
    requested model ("gpt-4o-mini") to that group.
    """
    from litellm.caching import DualCache
    from litellm.proxy.utils import ProxyLogging
    from litellm.router import Router

    # Router deployment puts gpt-4o-mini in the "test-group" access group
    router = Router(
        model_list=[
            {
                "model_name": "gpt-4o-mini",
                "litellm_params": {"model": "gpt-4o-mini"},
                "model_info": {"access_groups": ["test-group"]},
            }
        ]
    )
    import sys
    import types

    # Inject a stand-in proxy_server module so the auth code sees our router
    # as the global llm_router without importing the real proxy server
    proxy_server_module = types.ModuleType("proxy_server")
    proxy_server_module.llm_router = router
    monkeypatch.setitem(sys.modules, "litellm.proxy.proxy_server", proxy_server_module)

    # Team grants access via the access-group name, not the model name
    team = LiteLLM_TeamTable(team_id="team-1", models=["test-group"])

    async def mock_get_team_object(*args, **kwargs):  # type: ignore
        return team

    monkeypatch.setattr(
        "litellm.proxy.auth.handle_jwt.get_team_object", mock_get_team_object
    )

    jwt_handler = JWTHandler()
    jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth()

    user_api_key_cache = DualCache()
    proxy_logging_obj = ProxyLogging(user_api_key_cache=user_api_key_cache)

    team_id, team_obj = await JWTAuthManager.find_team_with_model_access(
        team_ids={"team-1"},
        requested_model="gpt-4o-mini",
        route="/chat/completions",
        jwt_handler=jwt_handler,
        prisma_client=None,
        user_api_key_cache=user_api_key_cache,
        parent_otel_span=None,
        proxy_logging_obj=proxy_logging_obj,
    )

    # The team is matched through its access-group entry
    assert team_id == "team-1"
    assert team_obj.team_id == "team-1"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_auth_builder_returns_team_membership_object():
    """
    Test that auth_builder returns the team_membership_object when user is a member of a team.
    """
    # Setup test data
    api_key = "test_jwt_token"
    request_data = {"model": "gpt-4"}
    general_settings = {"enforce_rbac": False}
    route = "/chat/completions"
    _team_id = "test_team_1"
    _user_id = "test_user_1"

    # Create mock objects
    from litellm.proxy._types import LiteLLM_BudgetTable, LiteLLM_TeamMembership

    # Membership row carrying a budget — this is the object auth_builder must surface
    mock_team_membership = LiteLLM_TeamMembership(
        user_id=_user_id,
        team_id=_team_id,
        budget_id="budget_123",
        spend=10.5,
        litellm_budget_table=LiteLLM_BudgetTable(
            budget_id="budget_123",
            rpm_limit=100,
            tpm_limit=5000
        )
    )

    user_object = LiteLLM_UserTable(
        user_id=_user_id,
        user_role=LitellmUserRoles.INTERNAL_USER
    )

    team_object = LiteLLM_TeamTable(team_id=_team_id)

    # Create mock JWT handler
    jwt_handler = JWTHandler()
    jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth()

    # Mock every collaborator auth_builder touches. The key stub is
    # get_objects, whose 4-tuple delivers mock_team_membership.
    with patch.object(
        jwt_handler, "auth_jwt", new_callable=AsyncMock
    ) as mock_auth_jwt, patch.object(
        JWTAuthManager, "check_rbac_role", new_callable=AsyncMock
    ) as mock_check_rbac, patch.object(
        jwt_handler, "get_rbac_role", return_value=None
    ) as mock_get_rbac, patch.object(
        jwt_handler, "get_scopes", return_value=[]
    ) as mock_get_scopes, patch.object(
        jwt_handler, "get_object_id", return_value=None
    ) as mock_get_object_id, patch.object(
        JWTAuthManager,
        "get_user_info",
        new_callable=AsyncMock,
        return_value=(_user_id, "test@example.com", True),
    ) as mock_get_user_info, patch.object(
        jwt_handler, "get_org_id", return_value=None
    ) as mock_get_org_id, patch.object(
        jwt_handler, "get_end_user_id", return_value=None
    ) as mock_get_end_user_id, patch.object(
        JWTAuthManager, "check_admin_access", new_callable=AsyncMock, return_value=None
    ) as mock_check_admin, patch.object(
        JWTAuthManager,
        "find_and_validate_specific_team_id",
        new_callable=AsyncMock,
        return_value=(_team_id, team_object),
    ) as mock_find_team, patch.object(
        JWTAuthManager, "get_all_team_ids", return_value=set()
    ) as mock_get_all_team_ids, patch.object(
        JWTAuthManager,
        "find_team_with_model_access",
        new_callable=AsyncMock,
        return_value=(None, None),
    ) as mock_find_team_access, patch.object(
        JWTAuthManager,
        "get_objects",
        new_callable=AsyncMock,
        return_value=(user_object, None, None, mock_team_membership),
    ) as mock_get_objects, patch.object(
        JWTAuthManager, "map_user_to_teams", new_callable=AsyncMock
    ) as mock_map_user, patch.object(
        JWTAuthManager, "validate_object_id", return_value=True
    ) as mock_validate_object, patch.object(
        JWTAuthManager, "sync_user_role_and_teams", new_callable=AsyncMock
    ) as mock_sync_user:
        # Set up the mock return values
        mock_auth_jwt.return_value = {"sub": _user_id, "scope": ""}

        # Call the auth_builder method
        result = await JWTAuthManager.auth_builder(
            api_key=api_key,
            jwt_handler=jwt_handler,
            request_data=request_data,
            general_settings=general_settings,
            route=route,
            prisma_client=None,
            user_api_key_cache=None,
            parent_otel_span=None,
            proxy_logging_obj=None,
        )

        # Verify that team_membership_object is returned
        assert result["team_membership"] is not None, "team_membership should be present"
        assert result["team_membership"] == mock_team_membership, "team_membership should match the mock object"
        assert result["team_membership"].user_id == _user_id, "team_membership user_id should match"
        assert result["team_membership"].team_id == _team_id, "team_membership team_id should match"
        assert result["team_membership"].budget_id == "budget_123", "team_membership budget_id should match"
        assert result["team_membership"].spend == 10.5, "team_membership spend should match"
|
@@ -0,0 +1,29 @@
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../../..")
|
||||
) # Adds the parent directory to the system path
|
||||
|
||||
from litellm.proxy.auth.litellm_license import LicenseCheck
|
||||
|
||||
|
||||
def test_is_over_limit():
    """LicenseCheck.is_over_limit enforces max_users only when the airgapped license defines it."""
    license_check = LicenseCheck()

    # (license data, {user count: expected over-limit flag}) — a missing or
    # empty license must never report over-limit
    scenarios = [
        ({"max_users": 100}, {101: True, 100: False, 99: False}),
        ({}, {101: False, 100: False, 99: False}),
        (None, {101: False, 100: False, 99: False}),
    ]
    for license_data, expectations in scenarios:
        license_check.airgapped_license_data = license_data
        for user_count, expected in expectations.items():
            assert license_check.is_over_limit(user_count) is expected
|
@@ -0,0 +1,64 @@
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from litellm.proxy._types import LiteLLM_TeamTable, LiteLLM_UserTable, Member
|
||||
from litellm.proxy.auth.handle_jwt import JWTAuthManager
|
||||
|
||||
|
||||
def test_get_team_models_for_all_models_and_team_only_models():
    """A team list containing 'all-proxy-models' exposes the team models plus every proxy model."""
    from litellm.proxy.auth.model_checks import get_team_models

    team_models = ["all-proxy-models", "team-only-model", "team-only-model-2"]
    proxy_model_list = ["model1", "model2", "model3"]

    # Empty access groups, and access-group expansion disabled
    result = get_team_models(team_models, proxy_model_list, {}, False)

    # The union of both lists must come back (order not asserted)
    assert set(result) == set(team_models) | set(proxy_model_list)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "key_models,team_models,proxy_model_list,model_list,expected",
    [
        # Key-level models resolved against a wildcard router deployment
        (
            ["anthropic/claude-3-haiku-20240307", "anthropic/claude-3-5-haiku-20241022"],
            [],
            [],
            [{"model_name": "anthropic/*", "litellm_params": {"model": "anthropic/*"}}],
            ["anthropic/claude-3-haiku-20240307", "anthropic/claude-3-5-haiku-20241022"]
        ),
        # Team-level models resolved the same way
        (
            [],
            ["anthropic/claude-3-haiku-20240307", "anthropic/claude-3-5-haiku-20241022"],
            [],
            [{"model_name": "anthropic/*", "litellm_params": {"model": "anthropic/*"}}],
            ["anthropic/claude-3-haiku-20240307", "anthropic/claude-3-5-haiku-20241022"]
        ),
        # Proxy-level model list resolved the same way
        (
            [],
            [],
            ["anthropic/claude-3-haiku-20240307", "anthropic/claude-3-5-haiku-20241022"],
            [{"model_name": "anthropic/*", "litellm_params": {"model": "anthropic/*"}}],
            ["anthropic/claude-3-haiku-20240307", "anthropic/claude-3-5-haiku-20241022"]
        ),
    ],
)
def test_get_complete_model_list_order(key_models, team_models, proxy_model_list, model_list, expected):
    """
    Test that get_complete_model_list preserves order
    """
    from litellm.proxy.auth.model_checks import get_complete_model_list
    from litellm import Router

    # The returned list must match `expected` exactly — same elements AND
    # same order — regardless of which source list supplied the models
    assert get_complete_model_list(
        proxy_model_list=proxy_model_list,
        key_models=key_models,
        team_models=team_models,
        user_model=None,
        infer_model_from_keys=False,
        llm_router=Router(model_list=model_list),
    ) == expected
|
@@ -0,0 +1,242 @@
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
|
||||
def create_mock_router(
    fallbacks=None, context_window_fallbacks=None, content_policy_fallbacks=None
):
    """Helper function to create a mock router with fallback configurations."""
    mock_router = Mock()
    fallback_configs = {
        "fallbacks": fallbacks,
        "context_window_fallbacks": context_window_fallbacks,
        "content_policy_fallbacks": content_policy_fallbacks,
    }
    for attr_name, config in fallback_configs.items():
        # Unset (or falsy) configs normalize to an empty list, like a real router default
        setattr(mock_router, attr_name, config or [])
    return mock_router
|
||||
|
||||
|
||||
def test_no_router_returns_empty_list():
    """Test that None router returns empty list."""
    from litellm.proxy.auth.model_checks import get_all_fallbacks

    # With no router there is no fallback config to consult
    assert get_all_fallbacks("claude-4-sonnet", llm_router=None) == []
|
||||
|
||||
|
||||
def test_no_fallbacks_config_returns_empty_list():
    """Test that empty fallbacks config returns empty list."""
    from litellm.proxy.auth.model_checks import get_all_fallbacks

    empty_router = create_mock_router(fallbacks=[])
    assert get_all_fallbacks("claude-4-sonnet", llm_router=empty_router) == []
|
||||
|
||||
|
||||
def test_model_with_fallbacks_returns_complete_list():
    """Test that model with fallbacks returns complete fallback list."""
    from litellm.proxy.auth.model_checks import get_all_fallbacks

    router = create_mock_router(
        fallbacks=[{"claude-4-sonnet": ["bedrock-claude-sonnet-4", "google-claude-sonnet-4"]}]
    )

    with patch(
        'litellm.proxy.auth.model_checks.get_fallback_model_group'
    ) as fallback_lookup:
        # Lookup resolves both configured fallbacks for the model group
        fallback_lookup.return_value = (
            ["bedrock-claude-sonnet-4", "google-claude-sonnet-4"],
            None,
        )
        assert get_all_fallbacks("claude-4-sonnet", llm_router=router) == [
            "bedrock-claude-sonnet-4",
            "google-claude-sonnet-4",
        ]
|
||||
|
||||
|
||||
def test_model_without_fallbacks_returns_empty_list():
    """Test that model without fallbacks returns empty list."""
    from litellm.proxy.auth.model_checks import get_all_fallbacks

    router = create_mock_router(
        fallbacks=[{"claude-4-sonnet": ["bedrock-claude-sonnet-4", "google-claude-sonnet-4"]}]
    )

    with patch(
        'litellm.proxy.auth.model_checks.get_fallback_model_group'
    ) as fallback_lookup:
        # The queried model has no entry, so the lookup resolves nothing
        fallback_lookup.return_value = (None, None)
        assert get_all_fallbacks("bedrock-claude-sonnet-4", llm_router=router) == []
|
||||
|
||||
|
||||
def test_general_fallback_type():
    """Test general fallback type uses router.fallbacks."""
    from litellm.proxy.auth.model_checks import get_all_fallbacks

    general_config = [{"claude-4-sonnet": ["bedrock-claude-sonnet-4"]}]
    router = create_mock_router(fallbacks=general_config)

    with patch(
        'litellm.proxy.auth.model_checks.get_fallback_model_group'
    ) as fallback_lookup:
        fallback_lookup.return_value = (["bedrock-claude-sonnet-4"], None)
        outcome = get_all_fallbacks("claude-4-sonnet", llm_router=router, fallback_type="general")

    assert outcome == ["bedrock-claude-sonnet-4"]
    # The lookup must have been fed router.fallbacks (the "general" config)
    fallback_lookup.assert_called_once_with(
        fallbacks=general_config,
        model_group="claude-4-sonnet"
    )
|
||||
|
||||
|
||||
def test_context_window_fallback_type():
    """Test context_window fallback type uses router.context_window_fallbacks."""
    from litellm.proxy.auth.model_checks import get_all_fallbacks

    context_config = [{"gpt-4": ["gpt-3.5-turbo"]}]
    router = create_mock_router(context_window_fallbacks=context_config)

    with patch(
        'litellm.proxy.auth.model_checks.get_fallback_model_group'
    ) as fallback_lookup:
        fallback_lookup.return_value = (["gpt-3.5-turbo"], None)
        outcome = get_all_fallbacks("gpt-4", llm_router=router, fallback_type="context_window")

    assert outcome == ["gpt-3.5-turbo"]
    # The lookup must have been fed router.context_window_fallbacks
    fallback_lookup.assert_called_once_with(
        fallbacks=context_config,
        model_group="gpt-4"
    )
|
||||
|
||||
|
||||
def test_content_policy_fallback_type():
    """Test content_policy fallback type uses router.content_policy_fallbacks."""
    from litellm.proxy.auth.model_checks import get_all_fallbacks

    content_config = [{"claude-4": ["claude-3"]}]
    router = create_mock_router(content_policy_fallbacks=content_config)

    with patch(
        'litellm.proxy.auth.model_checks.get_fallback_model_group'
    ) as fallback_lookup:
        fallback_lookup.return_value = (["claude-3"], None)
        outcome = get_all_fallbacks("claude-4", llm_router=router, fallback_type="content_policy")

    assert outcome == ["claude-3"]
    # The lookup must have been fed router.content_policy_fallbacks
    fallback_lookup.assert_called_once_with(
        fallbacks=content_config,
        model_group="claude-4"
    )
|
||||
|
||||
|
||||
def test_invalid_fallback_type_returns_empty_list():
    """Test that invalid fallback type returns empty list and logs warning."""
    from litellm.proxy.auth.model_checks import get_all_fallbacks

    router = create_mock_router(fallbacks=[])

    with patch('litellm.proxy.auth.model_checks.verbose_proxy_logger') as logger_mock:
        outcome = get_all_fallbacks("claude-4-sonnet", llm_router=router, fallback_type="invalid")

    assert outcome == []
    logger_mock.warning.assert_called_once_with("Unknown fallback_type: invalid")
|
||||
|
||||
|
||||
def test_exception_handling_returns_empty_list():
    """Test that exceptions are handled gracefully and return empty list."""
    from litellm.proxy.auth.model_checks import get_all_fallbacks

    router = create_mock_router(fallbacks=[{"claude-4-sonnet": ["fallback"]}])

    with patch(
        'litellm.proxy.auth.model_checks.get_fallback_model_group'
    ) as fallback_lookup, patch(
        'litellm.proxy.auth.model_checks.verbose_proxy_logger'
    ) as logger_mock:
        # Force the lookup itself to blow up
        fallback_lookup.side_effect = Exception("Test exception")
        outcome = get_all_fallbacks("claude-4-sonnet", llm_router=router)

    assert outcome == []
    # The failure is swallowed but logged with the model name
    logger_mock.error.assert_called_once()
    logged_message = logger_mock.error.call_args[0][0]
    assert "Error getting fallbacks for model claude-4-sonnet" in logged_message
|
||||
|
||||
|
||||
def test_multiple_fallbacks_complete_list():
    """Test model with multiple fallbacks returns the complete list."""
    from litellm.proxy.auth.model_checks import get_all_fallbacks

    router = create_mock_router(
        fallbacks=[{"gpt-4": ["gpt-4-turbo", "gpt-3.5-turbo", "claude-3-haiku"]}]
    )

    with patch(
        'litellm.proxy.auth.model_checks.get_fallback_model_group'
    ) as fallback_lookup:
        fallback_lookup.return_value = (["gpt-4-turbo", "gpt-3.5-turbo", "claude-3-haiku"], None)
        # All three configured fallbacks come back, in order
        assert get_all_fallbacks("gpt-4", llm_router=router) == [
            "gpt-4-turbo", "gpt-3.5-turbo", "claude-3-haiku"
        ]
|
||||
|
||||
|
||||
def test_wildcard_and_specific_fallbacks():
    """Fallback config mixing a wildcard ('*') entry with a model-specific entry."""
    from litellm.proxy.auth.model_checks import get_all_fallbacks

    mock_router = create_mock_router(
        fallbacks=[
            {"*": ["gpt-3.5-turbo"]},
            {"claude-4-sonnet": ["bedrock-claude-sonnet-4", "google-claude-sonnet-4"]},
        ]
    )

    with patch(
        "litellm.proxy.auth.model_checks.get_fallback_model_group"
    ) as fallback_mock:
        # A known model resolves to its model-specific fallback list.
        fallback_mock.return_value = (
            ["bedrock-claude-sonnet-4", "google-claude-sonnet-4"],
            None,
        )
        assert get_all_fallbacks("claude-4-sonnet", llm_router=mock_router) == [
            "bedrock-claude-sonnet-4",
            "google-claude-sonnet-4",
        ]

        # An unknown model falls through to the wildcard entry.
        fallback_mock.return_value = (["gpt-3.5-turbo"], 0)
        assert get_all_fallbacks("some-unknown-model", llm_router=mock_router) == [
            "gpt-3.5-turbo"
        ]
def test_default_fallback_type_is_general():
    """Omitting fallback_type should consult the router's general fallbacks."""
    from litellm.proxy.auth.model_checks import get_all_fallbacks

    general_fallbacks = [{"claude-4-sonnet": ["bedrock-claude-sonnet-4"]}]
    mock_router = create_mock_router(fallbacks=general_fallbacks)

    with patch(
        "litellm.proxy.auth.model_checks.get_fallback_model_group"
    ) as fallback_mock:
        fallback_mock.return_value = (["bedrock-claude-sonnet-4"], None)

        # No fallback_type argument -> the 'general' config must be used.
        result = get_all_fallbacks("claude-4-sonnet", llm_router=mock_router)

        # The helper must have been fed router.fallbacks (the general config).
        fallback_mock.assert_called_once_with(
            fallbacks=general_fallbacks, model_group="claude-4-sonnet"
        )
        assert result == ["bedrock-claude-sonnet-4"]
@@ -0,0 +1,207 @@
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../../..")
|
||||
) # Adds the parent directory to the system path
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException, Request
|
||||
|
||||
from litellm.proxy._types import LiteLLM_UserTable, LitellmUserRoles, UserAPIKeyAuth
|
||||
from litellm.proxy.auth.route_checks import RouteChecks
|
||||
|
||||
|
||||
def test_non_admin_config_update_route_rejected():
    """A plain internal user must not be able to hit /config/update."""
    internal_role = LitellmUserRoles.INTERNAL_USER.value

    user_obj = LiteLLM_UserTable(
        user_id="test_user",
        user_email="test@example.com",
        user_role=internal_role,
    )
    valid_token = UserAPIKeyAuth(user_id="test_user", user_role=internal_role)

    # Minimal request stand-in; the check only inspects query_params.
    mock_request = MagicMock(spec=Request)
    mock_request.query_params = {}

    with pytest.raises(Exception) as exc_info:
        RouteChecks.non_proxy_admin_allowed_routes_check(
            user_obj=user_obj,
            _user_role=internal_role,
            route="/config/update",
            request=mock_request,
            valid_token=valid_token,
            request_data={},
        )

    # The rejection message must identify the restriction, route, and role.
    message = str(exc_info.value)
    assert (
        "Only proxy admin can be used to generate, delete, update info for new keys/users/teams"
        in message
    )
    assert "Route=/config/update" in message
    assert "Your role=internal_user" in message
def test_proxy_admin_viewer_config_update_route_rejected():
    """A read-only proxy admin (viewer) gets a 403 HTTPException from /config/update."""
    viewer_role = LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY.value

    user_obj = LiteLLM_UserTable(
        user_id="viewer_user",
        user_email="viewer@example.com",
        user_role=viewer_role,
    )
    valid_token = UserAPIKeyAuth(user_id="viewer_user", user_role=viewer_role)

    # Minimal request stand-in; the check only inspects query_params.
    mock_request = MagicMock(spec=Request)
    mock_request.query_params = {}

    with pytest.raises(HTTPException) as exc_info:
        RouteChecks.non_proxy_admin_allowed_routes_check(
            user_obj=user_obj,
            _user_role=viewer_role,
            route="/config/update",
            request=mock_request,
            valid_token=valid_token,
            request_data={},
        )

    # Must be a proper 403 whose detail names the route restriction and role.
    assert exc_info.value.status_code == 403
    detail = str(exc_info.value.detail)
    assert "user not allowed to access this route" in detail
    assert "role= proxy_admin_viewer" in detail
def test_virtual_key_allowed_routes_with_litellm_routes_member_name_allowed():
    """allowed_routes may name a LiteLLMRoutes group; routes in that group pass."""
    token = UserAPIKeyAuth(
        user_id="test_user",
        allowed_routes=["openai_routes"],  # LiteLLMRoutes enum member name
    )

    # /chat/completions belongs to LiteLLMRoutes.openai_routes.value.
    allowed = RouteChecks.is_virtual_key_allowed_to_call_route(
        route="/chat/completions",
        valid_token=token,
    )

    assert allowed is True
def test_virtual_key_allowed_routes_with_litellm_routes_member_name_denied():
    """Routes outside the allowed LiteLLMRoutes group are rejected with a clear error."""
    token = UserAPIKeyAuth(
        user_id="test_user",
        allowed_routes=["info_routes"],  # LiteLLMRoutes enum member name
    )

    # /chat/completions is NOT part of LiteLLMRoutes.info_routes.value.
    with pytest.raises(Exception) as exc_info:
        RouteChecks.is_virtual_key_allowed_to_call_route(
            route="/chat/completions",
            valid_token=token,
        )

    # The error must spell out the restriction, the allowed set, and the attempt.
    message = str(exc_info.value)
    assert "Virtual key is not allowed to call this route" in message
    assert "Only allowed to call routes: ['info_routes']" in message
    assert "Tried to call route: /chat/completions" in message
def test_virtual_key_allowed_routes_with_multiple_litellm_routes_member_names():
    """Several LiteLLMRoutes group names can be combined in allowed_routes."""
    token = UserAPIKeyAuth(
        user_id="test_user", allowed_routes=["openai_routes", "info_routes"]
    )

    # One representative route from each configured group must be accepted.
    for route in ("/chat/completions", "/user/info"):
        assert (
            RouteChecks.is_virtual_key_allowed_to_call_route(
                route=route, valid_token=token
            )
            is True
        )
def test_virtual_key_allowed_routes_with_mixed_member_names_and_explicit_routes():
    """allowed_routes can mix a LiteLLMRoutes group name with literal route strings."""
    token = UserAPIKeyAuth(
        user_id="test_user",
        allowed_routes=["info_routes", "/custom/route"],
    )

    # /user/info comes from the info_routes group; /custom/route is listed verbatim.
    for route in ("/user/info", "/custom/route"):
        assert (
            RouteChecks.is_virtual_key_allowed_to_call_route(
                route=route, valid_token=token
            )
            is True
        )
def test_virtual_key_allowed_routes_with_no_member_names_only_explicit():
    """Literal-only allowed_routes: listed routes pass, anything else raises."""
    token = UserAPIKeyAuth(
        user_id="test_user",
        allowed_routes=["/chat/completions", "/custom/route"],
    )

    # Both explicitly listed routes must be accepted.
    for route in ("/chat/completions", "/custom/route"):
        assert (
            RouteChecks.is_virtual_key_allowed_to_call_route(
                route=route, valid_token=token
            )
            is True
        )

    # A route that never appears in allowed_routes must be rejected.
    with pytest.raises(Exception) as exc_info:
        RouteChecks.is_virtual_key_allowed_to_call_route(
            route="/user/info", valid_token=token
        )

    assert "Virtual key is not allowed to call this route" in str(exc_info.value)
@@ -0,0 +1,58 @@
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from typing import Tuple
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../../..")
|
||||
) # Adds the parent directory to the system path
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from litellm.proxy.auth.user_api_key_auth import get_api_key
|
||||
|
||||
|
||||
def test_get_api_key():
    """A standard Authorization bearer value is stripped to the bare key; the raw header is preserved."""
    raw_header = "Bearer sk-12345678"

    extracted = get_api_key(
        custom_litellm_key_header=None,
        api_key=raw_header,
        azure_api_key_header=None,
        anthropic_api_key_header=None,
        google_ai_studio_api_key_header=None,
        azure_apim_header=None,
        pass_through_endpoints=None,
        route="",
        request=MagicMock(),
    )

    # Returns (bare key, original header value as passed in).
    assert extracted == ("sk-12345678", raw_header)
||||
@pytest.mark.parametrize(
    "custom_litellm_key_header, api_key, passed_in_key",
    [
        ("Bearer sk-12345678", "sk-12345678", "Bearer sk-12345678"),
        ("Basic sk-12345678", "sk-12345678", "Basic sk-12345678"),
        ("bearer sk-12345678", "sk-12345678", "bearer sk-12345678"),
        ("sk-12345678", "sk-12345678", "sk-12345678"),
    ],
)
def test_get_api_key_with_custom_litellm_key_header(
    custom_litellm_key_header, api_key, passed_in_key
):
    """The custom LiteLLM key header takes effect and any auth-scheme prefix is stripped."""
    extracted = get_api_key(
        custom_litellm_key_header=custom_litellm_key_header,
        api_key=None,
        azure_api_key_header=None,
        anthropic_api_key_header=None,
        google_ai_studio_api_key_header=None,
        azure_apim_header=None,
        pass_through_endpoints=None,
        route="",
        request=MagicMock(),
    )

    # Returns (bare key, header value exactly as the client sent it).
    assert extracted == (api_key, passed_in_key)
|
Reference in New Issue
Block a user