Added LiteLLM to the stack
@@ -0,0 +1,578 @@
import asyncio
import os
import sys
from typing import Optional
from unittest.mock import AsyncMock, patch

import pytest

sys.path.insert(0, os.path.abspath("../.."))
import json

import litellm
from litellm.integrations.custom_logger import CustomLogger
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
from litellm.types.llms.openai import (
    IncompleteDetails,
    ResponseAPIUsage,
    ResponseCompletedEvent,
    ResponsesAPIResponse,
    ResponseTextConfig,
)
from litellm.types.utils import StandardLoggingPayload


@pytest.mark.asyncio
async def test_async_responses_api_routing_with_previous_response_id():
    """
    Test that when using a previous_response_id, the request is sent to the same model_id
    """
    # Create a mock response that simulates Azure responses API
    mock_response_id = "resp_mock-resp-456"

    mock_response_data = {
        "id": mock_response_id,
        "object": "response",
        "created_at": 1741476542,
        "status": "completed",
        "model": "azure/computer-use-preview",
        "output": [
            {
                "type": "message",
                "id": "msg_123",
                "status": "completed",
                "role": "assistant",
                "content": [
                    {
                        "type": "output_text",
                        "text": "I'm doing well, thank you for asking!",
                        "annotations": [],
                    }
                ],
            }
        ],
        "parallel_tool_calls": True,
        "usage": {
            "input_tokens": 10,
            "output_tokens": 20,
            "total_tokens": 30,
            "output_tokens_details": {"reasoning_tokens": 0},
        },
        "text": {"format": {"type": "text"}},
        "error": None,
        "incomplete_details": None,
        "instructions": None,
        "metadata": {},
        "temperature": 1.0,
        "tool_choice": "auto",
        "tools": [],
        "top_p": 1.0,
        "max_output_tokens": None,
        "previous_response_id": None,
        "reasoning": {"effort": None, "summary": None},
        "truncation": "disabled",
        "user": None,
    }

    class MockResponse:
        def __init__(self, json_data, status_code):
            self._json_data = json_data
            self.status_code = status_code
            self.text = json.dumps(json_data)

        def json(self):
            return self._json_data

    router = litellm.Router(
        model_list=[
            {
                "model_name": "azure-computer-use-preview",
                "litellm_params": {
                    "model": "azure/computer-use-preview",
                    "api_key": "mock-api-key",
                    "api_version": "mock-api-version",
                    "api_base": "https://mock-endpoint.openai.azure.com",
                },
            },
            {
                "model_name": "azure-computer-use-preview",
                "litellm_params": {
                    "model": "azure/computer-use-preview-2",
                    "api_key": "mock-api-key-2",
                    "api_version": "mock-api-version-2",
                    "api_base": "https://mock-endpoint-2.openai.azure.com",
                },
            },
        ],
        optional_pre_call_checks=["responses_api_deployment_check"],
    )
    MODEL = "azure-computer-use-preview"

    with patch(
        "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post",
        new_callable=AsyncMock,
    ) as mock_post:
        # Configure the mock to return our response
        mock_post.return_value = MockResponse(mock_response_data, 200)

        # Make the initial request
        # litellm._turn_on_debug()
        response = await router.aresponses(
            model=MODEL,
            input="Hello, how are you?",
            truncation="auto",
        )
        print("RESPONSE", response)

        # Store the model_id from the response
        expected_model_id = response._hidden_params["model_id"]
        response_id = response.id

        print("Response ID=", response_id, "came from model_id=", expected_model_id)

        # Make 10 other requests with previous_response_id, assert that they are sent to the same model_id
        for i in range(10):
            # Reset the mock for the next call
            mock_post.reset_mock()

            # Set up the mock to return our response again
            mock_post.return_value = MockResponse(mock_response_data, 200)

            response = await router.aresponses(
                model=MODEL,
                input=f"Follow-up question {i+1}",
                truncation="auto",
                previous_response_id=response_id,
            )

            # Assert the model_id is preserved
            assert response._hidden_params["model_id"] == expected_model_id
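

# The deployment affinity exercised above presumably boils down to remembering
# which deployment produced a given response id, and pinning follow-up calls
# that carry previous_response_id to that same deployment. A minimal sketch of
# that idea; the helpers below are hypothetical and are not LiteLLM's internals
# or public API.
_response_id_to_model_id = {}


def _remember_deployment(response_id, model_id):
    # Record which deployment served a /responses call.
    _response_id_to_model_id[response_id] = model_id


def _pinned_deployment(previous_response_id):
    # Return the deployment that produced previous_response_id if it is known,
    # otherwise None so the caller falls back to normal load balancing.
    return _response_id_to_model_id.get(previous_response_id)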


@pytest.mark.asyncio
async def test_async_routing_without_previous_response_id():
    """
    Test that normal routing (load balancing) works when no previous_response_id is provided
    """
    mock_response_data = {
        "id": "mock-resp-123",
        "object": "response",
        "created_at": 1741476542,
        "status": "completed",
        "model": "azure/computer-use-preview",
        "output": [
            {
                "type": "message",
                "id": "msg_123",
                "status": "completed",
                "role": "assistant",
                "content": [
                    {"type": "output_text", "text": "Hello there!", "annotations": []}
                ],
            }
        ],
        "parallel_tool_calls": True,
        "usage": {
            "input_tokens": 5,
            "output_tokens": 10,
            "total_tokens": 15,
            "output_tokens_details": {"reasoning_tokens": 0},
        },
        "text": {"format": {"type": "text"}},
        "error": None,
        "incomplete_details": None,
        "instructions": None,
        "metadata": {},
        "temperature": 1.0,
        "tool_choice": "auto",
        "tools": [],
        "top_p": 1.0,
        "max_output_tokens": None,
        "previous_response_id": None,
        "reasoning": {"effort": None, "summary": None},
        "truncation": "disabled",
        "user": None,
    }

    class MockResponse:
        def __init__(self, json_data, status_code):
            self._json_data = json_data
            self.status_code = status_code
            self.text = json.dumps(json_data)

        def json(self):
            return self._json_data

    # Create a router with four identical deployments to test load balancing
    router = litellm.Router(
        model_list=[
            {
                "model_name": "azure-computer-use-preview",
                "litellm_params": {
                    "model": "azure/computer-use-preview",
                    "api_key": "mock-api-key-1",
                    "api_version": "mock-api-version",
                    "api_base": "https://mock-endpoint-1.openai.azure.com",
                },
            },
            {
                "model_name": "azure-computer-use-preview",
                "litellm_params": {
                    "model": "azure/computer-use-preview",
                    "api_key": "mock-api-key-2",
                    "api_version": "mock-api-version",
                    "api_base": "https://mock-endpoint-2.openai.azure.com",
                },
            },
            {
                "model_name": "azure-computer-use-preview",
                "litellm_params": {
                    "model": "azure/computer-use-preview",
                    "api_key": "mock-api-key-3",
                    "api_version": "mock-api-version",
                    "api_base": "https://mock-endpoint-3.openai.azure.com",
                },
            },
            {
                "model_name": "azure-computer-use-preview",
                "litellm_params": {
                    "model": "azure/computer-use-preview",
                    "api_key": "mock-api-key-4",
                    "api_version": "mock-api-version",
                    "api_base": "https://mock-endpoint-4.openai.azure.com",
                },
            },
        ],
        optional_pre_call_checks=["responses_api_deployment_check"],
    )

    MODEL = "azure-computer-use-preview"

    with patch(
        "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post",
        new_callable=AsyncMock,
    ) as mock_post:
        # Configure the mock to return our response
        mock_post.return_value = MockResponse(mock_response_data, 200)

        # Make multiple requests and verify we're hitting different deployments
        used_model_ids = set()

        for i in range(20):
            response = await router.aresponses(
                model=MODEL,
                input=f"Question {i}",
                truncation="auto",
            )

            used_model_ids.add(response._hidden_params["model_id"])

        # We should have used more than one model_id if load balancing is working
        assert (
            len(used_model_ids) > 1
        ), "Load balancing isn't working, only one deployment was used"


@pytest.mark.asyncio
async def test_async_previous_response_id_not_in_cache():
    """
    Test behavior when a previous_response_id is provided but not found in cache
    """
    mock_response_data = {
        "id": "mock-resp-789",
        "object": "response",
        "created_at": 1741476542,
        "status": "completed",
        "model": "azure/computer-use-preview",
        "output": [
            {
                "type": "message",
                "id": "msg_123",
                "status": "completed",
                "role": "assistant",
                "content": [
                    {
                        "type": "output_text",
                        "text": "Nice to meet you!",
                        "annotations": [],
                    }
                ],
            }
        ],
        "parallel_tool_calls": True,
        "usage": {
            "input_tokens": 5,
            "output_tokens": 10,
            "total_tokens": 15,
            "output_tokens_details": {"reasoning_tokens": 0},
        },
        "text": {"format": {"type": "text"}},
        "error": None,
        "incomplete_details": None,
        "instructions": None,
        "metadata": {},
        "temperature": 1.0,
        "tool_choice": "auto",
        "tools": [],
        "top_p": 1.0,
        "max_output_tokens": None,
        "previous_response_id": None,
        "reasoning": {"effort": None, "summary": None},
        "truncation": "disabled",
        "user": None,
    }

    class MockResponse:
        def __init__(self, json_data, status_code):
            self._json_data = json_data
            self.status_code = status_code
            self.text = json.dumps(json_data)

        def json(self):
            return self._json_data

    router = litellm.Router(
        model_list=[
            {
                "model_name": "azure-computer-use-preview",
                "litellm_params": {
                    "model": "azure/computer-use-preview",
                    "api_key": "mock-api-key-1",
                    "api_version": "mock-api-version",
                    "api_base": "https://mock-endpoint-1.openai.azure.com",
                },
            },
            {
                "model_name": "azure-computer-use-preview",
                "litellm_params": {
                    "model": "azure/computer-use-preview",
                    "api_key": "mock-api-key-2",
                    "api_version": "mock-api-version",
                    "api_base": "https://mock-endpoint-2.openai.azure.com",
                },
            },
        ],
        optional_pre_call_checks=["responses_api_deployment_check"],
    )

    MODEL = "azure-computer-use-preview"

    with patch(
        "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post",
        new_callable=AsyncMock,
    ) as mock_post:
        # Configure the mock to return our response
        mock_post.return_value = MockResponse(mock_response_data, 200)

        # Make a request with a non-existent previous_response_id
        response = await router.aresponses(
            model=MODEL,
            input="Hello, this is a test",
            truncation="auto",
            previous_response_id="non-existent-response-id",
        )

        # Should still get a valid response
        assert response is not None
        assert response.id is not None

        # Since the previous_response_id wasn't found, routing should work normally
        # We can't assert exactly which deployment was chosen, but we can verify the basics
        assert response._hidden_params["model_id"] is not None


@pytest.mark.asyncio
async def test_async_multiple_response_ids_routing():
    """
    Test that different response IDs correctly route to their respective original deployments
    """
    # Create two different mock responses for our two different deployments
    mock_response_data_1 = {
        "id": "mock-resp-deployment-1",
        "object": "response",
        "created_at": 1741476542,
        "status": "completed",
        "model": "azure/computer-use-preview",
        "output": [
            {
                "type": "message",
                "id": "msg_123",
                "status": "completed",
                "role": "assistant",
                "content": [
                    {
                        "type": "output_text",
                        "text": "Response from deployment 1",
                        "annotations": [],
                    }
                ],
            }
        ],
        "parallel_tool_calls": True,
        "usage": {
            "input_tokens": 5,
            "output_tokens": 10,
            "total_tokens": 15,
            "output_tokens_details": {"reasoning_tokens": 0},
        },
        "text": {"format": {"type": "text"}},
        "error": None,
        "incomplete_details": None,
        "instructions": None,
        "metadata": {},
        "temperature": 1.0,
        "tool_choice": "auto",
        "tools": [],
        "top_p": 1.0,
        "max_output_tokens": None,
        "previous_response_id": None,
        "reasoning": {"effort": None, "summary": None},
        "truncation": "disabled",
        "user": None,
    }

    mock_response_data_2 = {
        "id": "mock-resp-deployment-2",
        "object": "response",
        "created_at": 1741476542,
        "status": "completed",
        "model": "azure/computer-use-preview",
        "output": [
            {
                "type": "message",
                "id": "msg_456",
                "status": "completed",
                "role": "assistant",
                "content": [
                    {
                        "type": "output_text",
                        "text": "Response from deployment 2",
                        "annotations": [],
                    }
                ],
            }
        ],
        "parallel_tool_calls": True,
        "usage": {
            "input_tokens": 5,
            "output_tokens": 10,
            "total_tokens": 15,
            "output_tokens_details": {"reasoning_tokens": 0},
        },
        "text": {"format": {"type": "text"}},
        "error": None,
        "incomplete_details": None,
        "instructions": None,
        "metadata": {},
        "temperature": 1.0,
        "tool_choice": "auto",
        "tools": [],
        "top_p": 1.0,
        "max_output_tokens": None,
        "previous_response_id": None,
        "reasoning": {"effort": None, "summary": None},
        "truncation": "disabled",
        "user": None,
    }

    class MockResponse:
        def __init__(self, json_data, status_code):
            self._json_data = json_data
            self.status_code = status_code
            self.text = json.dumps(json_data)

        def json(self):
            return self._json_data

    router = litellm.Router(
        model_list=[
            {
                "model_name": "azure-computer-use-preview",
                "litellm_params": {
                    "model": "azure/computer-use-preview-1",
                    "api_key": "mock-api-key-1",
                    "api_version": "mock-api-version",
                    "api_base": "https://mock-endpoint-1.openai.azure.com",
                },
            },
            {
                "model_name": "azure-computer-use-preview",
                "litellm_params": {
                    "model": "azure/computer-use-preview-2",
                    "api_key": "mock-api-key-2",
                    "api_version": "mock-api-version",
                    "api_base": "https://mock-endpoint-2.openai.azure.com",
                },
            },
        ],
        optional_pre_call_checks=["responses_api_deployment_check"],
    )

    MODEL = "azure-computer-use-preview"

    with patch(
        "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post",
        new_callable=AsyncMock,
    ) as mock_post:
        # For the first request, return response from deployment 1
        mock_post.return_value = MockResponse(mock_response_data_1, 200)

        # Make the first request to deployment 1
        response1 = await router.aresponses(
            model=MODEL,
            input="Request to deployment 1",
            truncation="auto",
        )

        # Store details from first response
        model_id_1 = response1._hidden_params["model_id"]
        response_id_1 = response1.id

        # For the second request, return response from deployment 2
        mock_post.return_value = MockResponse(mock_response_data_2, 200)

        # Make the second request to deployment 2
        response2 = await router.aresponses(
            model=MODEL,
            input="Request to deployment 2",
            truncation="auto",
        )

        # Store details from second response
        model_id_2 = response2._hidden_params["model_id"]
        response_id_2 = response2.id

        # Wait for cache updates
        await asyncio.sleep(1)

        # Now make follow-up requests using the previous response IDs

        # First, reset mock
        mock_post.reset_mock()
        mock_post.return_value = MockResponse(mock_response_data_1, 200)

        # Follow-up to response 1 should go to model_id_1
        follow_up_1 = await router.aresponses(
            model=MODEL,
            input="Follow up to deployment 1",
            truncation="auto",
            previous_response_id=response_id_1,
        )

        # Verify it went to the correct deployment
        assert follow_up_1._hidden_params["model_id"] == model_id_1

        # Reset mock again
        mock_post.reset_mock()
        mock_post.return_value = MockResponse(mock_response_data_2, 200)

        # Follow-up to response 2 should go to model_id_2
        follow_up_2 = await router.aresponses(
            model=MODEL,
            input="Follow up to deployment 2",
            truncation="auto",
            previous_response_id=response_id_2,
        )

        # Verify it went to the correct deployment
        assert follow_up_2._hidden_params["model_id"] == model_id_2
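
Outside of these mocked tests, the same affinity is driven entirely by passing previous_response_id back into router.aresponses on a Router created with the responses_api_deployment_check pre-call check. A minimal usage sketch follows; the credentials and endpoint below are placeholders, not real values.

import asyncio

import litellm


async def chained_responses_example():
    router = litellm.Router(
        model_list=[
            {
                "model_name": "azure-computer-use-preview",
                "litellm_params": {
                    "model": "azure/computer-use-preview",
                    "api_key": "<your-azure-api-key>",
                    "api_version": "<your-api-version>",
                    "api_base": "https://<your-endpoint>.openai.azure.com",
                },
            },
        ],
        # Same knob the tests above rely on for response_id -> deployment affinity.
        optional_pre_call_checks=["responses_api_deployment_check"],
    )

    first = await router.aresponses(
        model="azure-computer-use-preview",
        input="Hello, how are you?",
        truncation="auto",
    )

    # Passing previous_response_id routes the follow-up back to the deployment
    # that produced `first`.
    follow_up = await router.aresponses(
        model="azure-computer-use-preview",
        input="And what can you help me with?",
        truncation="auto",
        previous_response_id=first.id,
    )
    return follow_up


if __name__ == "__main__":
    asyncio.run(chained_responses_example())
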
@@ -0,0 +1,189 @@
from typing import Dict, List, Optional, Union
from unittest.mock import Mock

import pytest

from litellm.router_utils.common_utils import filter_team_based_models


class TestFilterTeamBasedModels:
    """Test cases for filter_team_based_models function"""

    @pytest.fixture
    def sample_deployments_with_teams(self) -> List[Dict]:
        """Sample deployments where some have team_id and some don't"""
        return [
            {"model_info": {"id": "deployment-1", "team_id": "team-a"}},
            {"model_info": {"id": "deployment-2", "team_id": "team-b"}},
            {
                "model_info": {
                    "id": "deployment-3"
                    # No team_id - should always be included
                }
            },
            {"model_info": {"id": "deployment-4", "team_id": "team-a"}},
        ]

    @pytest.fixture
    def sample_deployments_no_teams(self) -> List[Dict]:
        """Sample deployments with no team_id restrictions"""
        return [
            {"model_info": {"id": "deployment-1"}},
            {"model_info": {"id": "deployment-2"}},
        ]

    def test_filter_team_based_models_none_request_kwargs(
        self, sample_deployments_with_teams
    ):
        """Test that when request_kwargs is None, all deployments are returned unchanged"""
        result = filter_team_based_models(sample_deployments_with_teams, None)
        assert result == sample_deployments_with_teams

    def test_filter_team_based_models_empty_request_kwargs(
        self, sample_deployments_with_teams
    ):
        """Test with empty request_kwargs"""
        result = filter_team_based_models(sample_deployments_with_teams, {})
        # With no team_id in the request, only the deployment without a team_id
        # restriction (deployment-3) should be included
        assert len(result) == 1

    def test_filter_team_based_models_no_metadata(self, sample_deployments_with_teams):
        """Test with request_kwargs that has no metadata"""
        request_kwargs = {"some_other_key": "value"}
        result = filter_team_based_models(sample_deployments_with_teams, request_kwargs)
        # Should include only non-team based deployments
        assert len(result) == 1

    def test_filter_team_based_models_team_match_metadata(
        self, sample_deployments_with_teams
    ):
        """Test filtering when team_id is in metadata"""
        request_kwargs = {"metadata": {"user_api_key_team_id": "team-a"}}
        result = filter_team_based_models(sample_deployments_with_teams, request_kwargs)

        # Should include:
        # - deployment-1 (team-a matches)
        # - deployment-3 (no team_id restriction)
        # - deployment-4 (team-a matches)
        # Should exclude:
        # - deployment-2 (team-b doesn't match)
        expected_ids = ["deployment-1", "deployment-3", "deployment-4"]
        result_ids = [d.get("model_info", {}).get("id") for d in result]
        assert sorted(result_ids) == sorted(expected_ids)

    def test_filter_team_based_models_team_match_litellm_metadata(
        self, sample_deployments_with_teams
    ):
        """Test filtering when team_id is in litellm_metadata"""
        request_kwargs = {"litellm_metadata": {"user_api_key_team_id": "team-b"}}
        result = filter_team_based_models(sample_deployments_with_teams, request_kwargs)

        # Should include:
        # - deployment-2 (team-b matches)
        # - deployment-3 (no team_id restriction)
        # Should exclude:
        # - deployment-1 (team-a doesn't match)
        # - deployment-4 (team-a doesn't match)
        expected_ids = ["deployment-2", "deployment-3"]
        result_ids = [d.get("model_info", {}).get("id") for d in result]
        assert sorted(result_ids) == sorted(expected_ids)

    def test_filter_team_based_models_priority_metadata_over_litellm(
        self, sample_deployments_with_teams
    ):
        """Test that metadata.user_api_key_team_id takes priority over litellm_metadata.user_api_key_team_id"""
        request_kwargs = {
            "metadata": {"user_api_key_team_id": "team-a"},  # This should take priority
            "litellm_metadata": {"user_api_key_team_id": "team-b"},
        }
        result = filter_team_based_models(sample_deployments_with_teams, request_kwargs)

        # Should filter based on team-a (from metadata, not litellm_metadata)
        expected_ids = ["deployment-1", "deployment-3", "deployment-4"]
        result_ids = [d.get("model_info", {}).get("id") for d in result]
        assert sorted(result_ids) == sorted(expected_ids)

    def test_filter_team_based_models_no_matching_team(
        self, sample_deployments_with_teams
    ):
        """Test when request team doesn't match any deployment teams"""
        request_kwargs = {"metadata": {"user_api_key_team_id": "team-nonexistent"}}
        result = filter_team_based_models(sample_deployments_with_teams, request_kwargs)

        # Should only include deployment-3 (no team_id restriction)
        expected_ids = ["deployment-3"]
        result_ids = [d.get("model_info", {}).get("id") for d in result]
        assert result_ids == expected_ids

    def test_filter_team_based_models_no_team_restrictions(
        self, sample_deployments_no_teams
    ):
        """Test with deployments that have no team restrictions"""
        request_kwargs = {"metadata": {"user_api_key_team_id": "any-team"}}
        result = filter_team_based_models(sample_deployments_no_teams, request_kwargs)

        # Should include all deployments since none have team_id restrictions
        assert result == sample_deployments_no_teams

    def test_filter_team_based_models_missing_model_info(self):
        """Test with deployments missing model_info"""
        deployments = [
            {"model_info": {"id": "deployment-1", "team_id": "team-a"}},
            {
                # Missing model_info entirely
            },
            {
                "model_info": {
                    # Missing id
                    "team_id": "team-b"
                }
            },
        ]

        request_kwargs = {"metadata": {"user_api_key_team_id": "team-a"}}
        result = filter_team_based_models(deployments, request_kwargs)

        # Should handle missing model_info gracefully:
        # - deployment-1 matches team-a and should be included
        # - the deployment with no model_info has no team restriction, so it should be included
        # - the team-b deployment doesn't match the request's team, so it may be excluded
        assert len(result) >= 1  # At least deployment-1 should be included

    def test_filter_team_based_models_dict_input(self):
        """Test with Dict input instead of List[Dict]"""
        # Note: Based on the function signature, it accepts Union[List[Dict], Dict]
        # But the implementation seems to expect List[Dict] for the filtering logic
        # This test documents the current behavior
        deployments_dict = {"key1": "value1", "key2": "value2"}

        request_kwargs = {"metadata": {"user_api_key_team_id": "team-a"}}

        # This should not crash, though the filtering logic won't apply to Dict input
        result = filter_team_based_models(deployments_dict, request_kwargs)
        # The function will likely return the dict unchanged or handle it differently
        assert result is not None

    def test_filter_team_based_models_empty_deployments(self):
        """Test with empty deployments list"""
        result = filter_team_based_models(
            [], {"metadata": {"user_api_key_team_id": "team-a"}}
        )
        assert result == []

    def test_filter_team_based_models_none_team_id_in_deployment(self):
        """Test with explicit None team_id in deployment"""
        deployments = [
            {"model_info": {"id": "deployment-1", "team_id": None}},
            {"model_info": {"id": "deployment-2", "team_id": "team-a"}},
        ]

        request_kwargs = {"metadata": {"user_api_key_team_id": "team-a"}}
        result = filter_team_based_models(deployments, request_kwargs)

        # Both should be included:
        # - deployment-1 (None team_id is treated as no restriction)
        # - deployment-2 (team matches)
        expected_ids = ["deployment-1", "deployment-2"]
        result_ids = [d.get("model_info", {}).get("id") for d in result]
        assert sorted(result_ids) == sorted(expected_ids)
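
Taken together, the tests above pin down the expected behavior of filter_team_based_models. The following is a behavioral sketch reconstructed from those expectations only; it is not the implementation in litellm.router_utils.common_utils, and the function name is made up for illustration.

from typing import Dict, List, Optional, Union


def filter_team_based_models_sketch(
    healthy_deployments: Union[List[Dict], Dict],
    request_kwargs: Optional[Dict] = None,
) -> Union[List[Dict], Dict]:
    # No request context at all: return the deployments unchanged.
    if request_kwargs is None:
        return healthy_deployments
    # Only list input is filtered; dict input is passed through untouched.
    if not isinstance(healthy_deployments, list):
        return healthy_deployments

    # metadata.user_api_key_team_id takes priority over litellm_metadata.user_api_key_team_id.
    team_id = (request_kwargs.get("metadata") or {}).get("user_api_key_team_id") or (
        request_kwargs.get("litellm_metadata") or {}
    ).get("user_api_key_team_id")

    filtered = []
    for deployment in healthy_deployments:
        deployment_team_id = deployment.get("model_info", {}).get("team_id")
        # Deployments without a team_id (or with team_id=None) are open to everyone;
        # team-scoped deployments are kept only when the request's team matches.
        if deployment_team_id is None or deployment_team_id == team_id:
            filtered.append(deployment)
    return filtered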