Added LiteLLM to the stack
This commit is contained in:
@@ -0,0 +1,183 @@
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
from typing import Any, Dict, List, Optional
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../../..")
|
||||
) # Adds the parent directory to the system path
|
||||
|
||||
from litellm.router_strategy.auto_router.auto_router import AutoRouter
|
||||
|
||||
pytestmark = pytest.mark.skip(reason="Skipping auto router tests - beta feature")
|
||||
|
||||
|
||||
@pytest.fixture
def mock_router_instance():
    """Build a stand-in LiteLLM Router whose async completion entry point is mocked."""
    fake_router = MagicMock()
    fake_router.acompletion = AsyncMock()
    return fake_router
|
||||
|
||||
|
||||
@pytest.fixture
def mock_semantic_router():
    """Provide a fake SemanticRouter carrying exactly one named route."""
    route_stub = MagicMock()
    route_stub.name = "test-route"

    semantic_router_stub = MagicMock()
    semantic_router_stub.routes = [route_stub]
    return semantic_router_stub
|
||||
|
||||
|
||||
@pytest.fixture
def mock_route_choice():
    """Supply a fake RouteChoice whose name points at the test model."""
    choice_stub = MagicMock()
    choice_stub.name = "test-model"
    return choice_stub
|
||||
|
||||
|
||||
class TestAutoRouter:
    """Unit tests for AutoRouter: constructor wiring and the pre-routing hook.

    Note: the whole module is skipped via the module-level ``pytestmark``
    (beta feature), so these run only when that mark is removed.
    """

    @patch('semantic_router.routers.SemanticRouter')
    def test_init(self, mock_semantic_router_class, mock_router_instance):
        """Test that AutoRouter initializes correctly with all required parameters."""
        # Arrange
        # NOTE(review): from_json is pointed back at the mock class itself;
        # any MagicMock sentinel would do, since only call wiring is asserted.
        mock_semantic_router_class.from_json.return_value = mock_semantic_router_class

        model_name = "test-auto-router"
        router_config_path = "test/path/router.json"
        default_model = "gpt-4o-mini"
        embedding_model = "text-embedding-model"

        # Act
        auto_router = AutoRouter(
            model_name=model_name,
            auto_router_config_path=router_config_path,
            default_model=default_model,
            embedding_model=embedding_model,
            litellm_router_instance=mock_router_instance,
        )

        # Assert: constructor stores its inputs verbatim ...
        assert auto_router.auto_router_config_path == router_config_path
        assert auto_router.auto_sync_value == AutoRouter.DEFAULT_AUTO_SYNC_VALUE
        assert auto_router.default_model == default_model
        assert auto_router.embedding_model == embedding_model
        assert auto_router.litellm_router_instance == mock_router_instance
        # ... defers building the route layer until first use ...
        assert auto_router.routelayer is None
        # ... and loads the route config exactly once from the given path.
        mock_semantic_router_class.from_json.assert_called_once_with(router_config_path)

    @pytest.mark.asyncio
    @patch('semantic_router.routers.SemanticRouter')
    @patch('litellm.router_strategy.auto_router.litellm_encoder.LiteLLMRouterEncoder')
    async def test_async_pre_routing_hook_with_route_choice(
        self,
        mock_encoder_class,          # innermost @patch (LiteLLMRouterEncoder)
        mock_semantic_router_class,  # outermost @patch (SemanticRouter)
        mock_router_instance,        # pytest fixture
        mock_route_choice            # pytest fixture
    ):
        """Test async_pre_routing_hook returns correct model when route is found."""
        # Arrange: from_json yields a router that has routes, and calling the
        # constructed SemanticRouter instance returns a single RouteChoice.
        mock_loaded_router = MagicMock()
        mock_loaded_router.routes = ["route1", "route2"]
        mock_semantic_router_class.from_json.return_value = mock_loaded_router

        mock_routelayer = MagicMock()
        mock_routelayer.return_value = mock_route_choice
        mock_semantic_router_class.return_value = mock_routelayer

        auto_router = AutoRouter(
            model_name="test-auto-router",
            auto_router_config_path="test/path/router.json",
            default_model="gpt-4o-mini",
            embedding_model="text-embedding-model",
            litellm_router_instance=mock_router_instance,
        )

        messages = [{"role": "user", "content": "test message"}]

        # Act
        result = await auto_router.async_pre_routing_hook(
            model="test-model",
            request_kwargs={},
            messages=messages
        )

        # Assert
        assert result is not None
        # mock_route_choice.name is "test-model", so the hook's pick and the
        # requested model coincide here by construction.
        assert result.model == "test-model"  # Should use the route choice name
        assert result.messages == messages
        # The route layer is queried with the message content.
        mock_routelayer.assert_called_once_with(text="test message")

    @pytest.mark.asyncio
    @patch('semantic_router.routers.SemanticRouter')
    @patch('litellm.router_strategy.auto_router.litellm_encoder.LiteLLMRouterEncoder')
    async def test_async_pre_routing_hook_with_list_route_choice(
        self,
        mock_encoder_class,          # innermost @patch (LiteLLMRouterEncoder)
        mock_semantic_router_class,  # outermost @patch (SemanticRouter)
        mock_router_instance,        # pytest fixture
        mock_route_choice            # pytest fixture
    ):
        """Test async_pre_routing_hook handles list of RouteChoice objects correctly."""
        # Arrange: same wiring as above, except the route layer returns a LIST
        # of RouteChoice objects — the hook must unwrap it.
        mock_loaded_router = MagicMock()
        mock_loaded_router.routes = ["route1", "route2"]
        mock_semantic_router_class.from_json.return_value = mock_loaded_router

        mock_routelayer = MagicMock()
        mock_routelayer.return_value = [mock_route_choice]  # Return list
        mock_semantic_router_class.return_value = mock_routelayer

        auto_router = AutoRouter(
            model_name="test-auto-router",
            auto_router_config_path="test/path/router.json",
            default_model="gpt-4o-mini",
            embedding_model="text-embedding-model",
            litellm_router_instance=mock_router_instance,
        )

        messages = [{"role": "user", "content": "test message"}]

        # Act
        result = await auto_router.async_pre_routing_hook(
            model="test-model",
            request_kwargs={},
            messages=messages
        )

        # Assert: the single list element's name drives the chosen model.
        assert result is not None
        assert result.model == "test-model"
        assert result.messages == messages

    @pytest.mark.asyncio
    async def test_async_pre_routing_hook_no_messages(self, mock_router_instance):
        """Test async_pre_routing_hook returns None when no messages provided."""
        # Arrange: SemanticRouter is patched only for the constructor; the hook
        # should bail out before ever consulting the route layer.
        with patch('semantic_router.routers.SemanticRouter'):
            auto_router = AutoRouter(
                model_name="test-auto-router",
                auto_router_config_path="test/path/router.json",
                default_model="gpt-4o-mini",
                embedding_model="text-embedding-model",
                litellm_router_instance=mock_router_instance,
            )

        # Act
        result = await auto_router.async_pre_routing_hook(
            model="test-model",
            request_kwargs={},
            messages=None
        )

        # Assert: no messages -> no routing decision.
        assert result is None
|
||||
|
@@ -0,0 +1,153 @@
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from typing import Any, Dict, List, Optional, Set, Union
|
||||
|
||||
import pytest
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../../..")
|
||||
) # Adds the parent directory to the system path
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from litellm.caching.caching import DualCache
|
||||
from litellm.caching.redis_cache import RedisPipelineIncrementOperation
|
||||
from litellm.router_strategy.base_routing_strategy import BaseRoutingStrategy
|
||||
|
||||
|
||||
@pytest.fixture
def mock_dual_cache():
    """A DualCache double whose async cache methods resolve to canned results."""

    def _resolved(result: Any) -> asyncio.Future:
        # A pre-completed future: awaiting it (any number of times) yields `result`.
        fut: asyncio.Future = asyncio.Future()
        fut.set_result(result)
        return fut

    cache = MagicMock(spec=DualCache)
    cache.in_memory_cache = MagicMock()
    cache.redis_cache = MagicMock()

    # Awaitable stand-ins for the async cache API used by the strategy.
    cache.in_memory_cache.async_increment.return_value = _resolved(None)
    cache.redis_cache.async_increment_pipeline.return_value = _resolved(None)
    cache.in_memory_cache.async_set_cache.return_value = _resolved(None)
    cache.redis_cache.async_batch_get_cache.return_value = _resolved(
        {"key1": "10.0", "key2": "20.0"}
    )

    return cache
|
||||
|
||||
|
||||
@pytest.fixture
def base_strategy(mock_dual_cache):
    """A BaseRoutingStrategy wired to the mocked dual cache, with batching off."""
    strategy = BaseRoutingStrategy(
        dual_cache=mock_dual_cache,
        should_batch_redis_writes=False,
        default_sync_interval=1,
    )
    return strategy
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_increment_value_in_current_window(base_strategy, mock_dual_cache):
    """Incrementing a key updates the in-memory cache and queues a Redis op."""
    test_key, increment_by, expiry = "test_key", 10.0, 3600

    await base_strategy._increment_value_in_current_window(
        test_key, increment_by, expiry
    )

    # The in-memory layer must see exactly one increment with these arguments.
    mock_dual_cache.in_memory_cache.async_increment.assert_called_once_with(
        key=test_key, value=increment_by, ttl=expiry
    )

    # The same operation is buffered for a later Redis flush.
    queue = base_strategy.redis_increment_operation_queue
    assert len(queue) == 1
    pending = queue[0]
    assert isinstance(pending, dict)
    assert pending["key"] == test_key
    assert pending["increment_value"] == increment_by
    assert pending["ttl"] == expiry
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_push_in_memory_increments_to_redis(base_strategy, mock_dual_cache):
    """Flushing the queue issues one Redis pipeline call and empties the queue."""
    base_strategy.redis_increment_operation_queue = [
        RedisPipelineIncrementOperation(key="key1", increment_value=10, ttl=3600),
        RedisPipelineIncrementOperation(key="key2", increment_value=20, ttl=3600),
    ]

    await base_strategy._push_in_memory_increments_to_redis()

    # Exactly one pipelined increment call should have reached Redis ...
    mock_dual_cache.redis_cache.async_increment_pipeline.assert_called_once()
    # ... and the local buffer should now be empty.
    assert not base_strategy.redis_increment_operation_queue
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_sync_in_memory_spend_with_redis(base_strategy, mock_dual_cache):
    """Syncing merges Redis values with in-memory deltas into a final spend."""
    # NOTE(review): this local import shadows the module-level
    # RedisPipelineIncrementOperation from litellm.caching.redis_cache —
    # presumably they are the same type re-exported; confirm and drop one.
    from litellm.types.caching import RedisPipelineIncrementOperation

    # Setup test data: one tracked key with one pending +10 increment.
    base_strategy.in_memory_keys_to_update = {"key1"}
    base_strategy.redis_increment_operation_queue = [
        RedisPipelineIncrementOperation(key="key1", increment_value=10, ttl=3600),
    ]

    # Mock the in-memory cache batch get responses for before snapshot
    in_memory_before_future: asyncio.Future[List[str]] = asyncio.Future()
    in_memory_before_future.set_result(["5.0"])  # Initial values
    mock_dual_cache.in_memory_cache.async_batch_get_cache.return_value = (
        in_memory_before_future
    )

    # Mock the Redis pipeline result (post-increment totals per key).
    # NOTE(review): despite the name, this patches async_increment_pipeline,
    # not a batch *get* — confirm against the strategy implementation.
    redis_future: asyncio.Future[List[float]] = asyncio.Future()
    redis_future.set_result([15.0])  # Redis values
    mock_dual_cache.redis_cache.async_increment_pipeline.return_value = redis_future

    # Mock in-memory get for after snapshot
    in_memory_after_future: asyncio.Future[Optional[str]] = asyncio.Future()
    in_memory_after_future.set_result("8.0")  # Value after potential updates
    mock_dual_cache.in_memory_cache.async_get_cache.return_value = (
        in_memory_after_future
    )

    await base_strategy._sync_in_memory_spend_with_redis()

    # Verify the final merged value written back to the in-memory cache:
    # expected 18.0 = 15.0 (Redis total) + (8.0 - 5.0) (local delta since snapshot).
    set_cache_calls = mock_dual_cache.in_memory_cache.async_set_cache.call_args_list
    print(f"set_cache_calls: {set_cache_calls}")
    assert any(
        call.kwargs["key"] == "key1" and float(call.kwargs["value"]) == 18.0
        for call in set_cache_calls
    )

    # The tracked-key set is preserved (sync must not clear it).
    assert len(base_strategy.in_memory_keys_to_update) == 1
|
||||
|
||||
|
||||
def test_cache_keys_management(base_strategy):
    """Tracked keys are stored once each, retrievable, and resettable."""
    # The repeated "key1" must be deduplicated by the tracker.
    for tracked_key in ("key1", "key2", "key1"):
        base_strategy.add_to_in_memory_keys_to_update(tracked_key)

    tracked = base_strategy.get_in_memory_keys_to_update()
    assert len(tracked) == 2
    assert {"key1", "key2"} <= set(tracked)

    # After a reset the tracker is empty again.
    base_strategy.reset_in_memory_keys_to_update()
    assert len(base_strategy.get_in_memory_keys_to_update()) == 0
|
Reference in New Issue
Block a user