Added LiteLLM to the stack

This commit is contained in:
2025-08-18 09:40:50 +00:00
parent 0648c1968c
commit d220b04e32
2682 changed files with 533609 additions and 1 deletions

View File

@@ -0,0 +1,198 @@
import os
import sys
from unittest.mock import MagicMock, patch, AsyncMock
import pytest
from fastapi.testclient import TestClient
sys.path.insert(
0, os.path.abspath("../../..")
) # Adds the parent directory to the system path
from litellm.caching.azure_blob_cache import AzureBlobCache
@pytest.fixture
def mock_azure_dependencies():
    """Stub every Azure SDK touchpoint so no real credentials are required."""
    # Container clients that the cache instance will end up holding.
    sync_container = MagicMock()
    async_container = AsyncMock()

    # Fake credentials for the sync and async identity paths.
    sync_credential = MagicMock()
    async_credential = AsyncMock()

    # Blob service clients hand back the container clients above.
    sync_service = MagicMock()
    sync_service.get_container_client.return_value = sync_container

    async_service = AsyncMock()
    # get_container_client must return the mock itself, not a coroutine,
    # so override it with a plain MagicMock on the AsyncMock.
    async_service.get_container_client = MagicMock(return_value=async_container)

    mocks = {
        "container_client": sync_container,
        "async_container_client": async_container,
        "blob_service_client": sync_service,
        "async_blob_service_client": async_service,
        "credential": sync_credential,
        "async_credential": async_credential,
    }

    # Patch the Azure entry points at their source locations.
    with patch("azure.identity.DefaultAzureCredential", return_value=sync_credential), \
            patch("azure.identity.aio.DefaultAzureCredential", return_value=async_credential), \
            patch("azure.storage.blob.BlobServiceClient", return_value=sync_service), \
            patch("azure.storage.blob.aio.BlobServiceClient", return_value=async_service), \
            patch("azure.core.exceptions.ResourceExistsError"):
        yield mocks
@pytest.mark.asyncio
async def test_blob_cache_async_get_cache(mock_azure_dependencies):
    """async_get_cache should download the blob and JSON-decode its bytes."""
    cache = AzureBlobCache("https://my-test-host", "test-container")

    # Fake blob whose payload is a JSON document.
    blob_stub = AsyncMock()
    blob_stub.readall.return_value = b'{"test_key": "test_value"}'
    cache.async_container_client.download_blob.return_value = blob_stub

    fetched = await cache.async_get_cache("test_key")

    cache.async_container_client.download_blob.assert_called_once_with("test_key")
    blob_stub.readall.assert_called_once()
    assert fetched == {"test_key": "test_value"}
@pytest.mark.asyncio
async def test_blob_cache_async_get_cache_not_found(mock_azure_dependencies):
    """A missing blob should surface as a cache miss (None), not an error."""
    # Imported lazily so collection does not require the Azure SDK up front.
    from azure.core.exceptions import ResourceNotFoundError

    cache = AzureBlobCache("https://my-test-host", "test-container")
    cache.async_container_client.download_blob.side_effect = ResourceNotFoundError(
        "Blob not found"
    )

    missing = await cache.async_get_cache("nonexistent_key")

    cache.async_container_client.download_blob.assert_called_once_with("nonexistent_key")
    assert missing is None
@pytest.mark.asyncio
async def test_blob_cache_async_set_cache(mock_azure_dependencies):
    """async_set_cache should serialize the value and upload it with overwrite."""
    cache = AzureBlobCache("https://my-test-host", "test-container")

    await cache.async_set_cache("test_key", {"key": "value", "number": 42})

    # The payload must be JSON text and an existing blob must be replaced.
    cache.async_container_client.upload_blob.assert_called_once_with(
        "test_key", '{"key": "value", "number": 42}', overwrite=True
    )
def test_blob_cache_sync_get_cache(mock_azure_dependencies):
    """get_cache should download the blob synchronously and JSON-decode it."""
    cache = AzureBlobCache("https://my-test-host", "test-container")

    blob_stub = MagicMock()
    blob_stub.readall.return_value = b'{"sync_key": "sync_value"}'
    cache.container_client.download_blob.return_value = blob_stub

    fetched = cache.get_cache("sync_key")

    cache.container_client.download_blob.assert_called_once_with("sync_key")
    blob_stub.readall.assert_called_once()
    assert fetched == {"sync_key": "sync_value"}
def test_blob_cache_sync_set_cache(mock_azure_dependencies):
    """set_cache should upload the JSON-serialized value under the given key."""
    cache = AzureBlobCache("https://my-test-host", "test-container")

    cache.set_cache("sync_test_key", {"sync_key": "sync_value", "number": 123})

    # NOTE: unlike the async path, the sync upload is made without overwrite.
    cache.container_client.upload_blob.assert_called_once_with(
        "sync_test_key", '{"sync_key": "sync_value", "number": 123}'
    )
def test_blob_cache_sync_get_cache_not_found(mock_azure_dependencies):
    """A missing blob in the sync path should return None instead of raising."""
    from azure.core.exceptions import ResourceNotFoundError

    cache = AzureBlobCache("https://my-test-host", "test-container")
    cache.container_client.download_blob.side_effect = ResourceNotFoundError(
        "Blob not found"
    )

    missing = cache.get_cache("nonexistent_key")

    cache.container_client.download_blob.assert_called_once_with("nonexistent_key")
    assert missing is None
@pytest.mark.asyncio
async def test_blob_cache_async_set_cache_pipeline(mock_azure_dependencies):
    """async_set_cache_pipeline should upload every (key, value) pair it is given."""
    cache = AzureBlobCache("https://my-test-host", "test-container")

    pairs = [(f"key{idx}", {"value": f"data{idx}"}) for idx in (1, 2, 3)]
    await cache.async_set_cache_pipeline(pairs)

    upload = cache.async_container_client.upload_blob
    assert upload.call_count == 3
    # Each entry is serialized to JSON and written with overwrite semantics.
    for idx in (1, 2, 3):
        upload.assert_any_call(f"key{idx}", f'{{"value": "data{idx}"}}', overwrite=True)

View File

@@ -0,0 +1,54 @@
import asyncio
import json
import os
import sys
import time
from unittest.mock import MagicMock, patch
import httpx
import pytest
import respx
from fastapi.testclient import TestClient
sys.path.insert(
0, os.path.abspath("../../..")
) # Adds the parent directory to the system path
from datetime import datetime
from unittest.mock import AsyncMock, MagicMock
from litellm.caching.caching_handler import LLMCachingHandler
@pytest.mark.asyncio
async def test_process_async_embedding_cached_response():
    """A fully cached embedding lookup must be flagged as a cache hit and
    yield a response object carrying exactly one embedding entry."""
    caching_handler = LLMCachingHandler(
        original_function=MagicMock(),
        request_kwargs={},
        start_time=datetime.now(),
    )

    cached_embedding = [
        {
            "embedding": [-0.025122925639152527, -0.019487135112285614],
            "index": 0,
            "object": "embedding",
        }
    ]

    logging_stub = MagicMock()
    logging_stub.async_success_handler = AsyncMock()

    response, cache_hit = caching_handler._process_async_embedding_cached_response(
        final_embedding_cached_response=None,
        cached_result=cached_embedding,
        kwargs={"model": "text-embedding-ada-002", "input": "test"},
        logging_obj=logging_stub,
        start_time=datetime.now(),
        model="text-embedding-ada-002",
    )

    assert cache_hit
    print(f"response: {response}")
    assert len(response.data) == 1

View File

@@ -0,0 +1,36 @@
import os
import sys
from unittest.mock import MagicMock, AsyncMock, patch
import pytest
sys.path.insert(0, os.path.abspath("../../.."))
from litellm.caching.gcs_cache import GCSCache
@pytest.fixture
def mock_gcs_dependencies():
    """Replace the GCS cache's httpx clients and auth-header builder with stubs."""
    sync_stub = MagicMock()
    async_stub = AsyncMock()

    with patch("litellm.caching.gcs_cache._get_httpx_client", return_value=sync_stub):
        with patch(
            "litellm.caching.gcs_cache.get_async_httpx_client",
            return_value=async_stub,
        ):
            # Header construction normally needs real GCS credentials.
            with patch(
                "litellm.caching.gcs_cache.GCSBucketBase.sync_construct_request_headers",
                return_value={},
            ):
                yield {"sync_client": sync_stub, "async_client": async_stub}
@pytest.mark.asyncio
async def test_gcs_cache_async_set_and_get(mock_gcs_dependencies):
    """Round trip: async_set_cache POSTs the value; async_get_cache reads it back."""
    gcs_cache = GCSCache(bucket_name="test-bucket")

    await gcs_cache.async_set_cache("key", {"foo": "bar"})
    mock_gcs_dependencies["async_client"].post.assert_called_once()

    # Arrange a successful download whose body is the JSON we stored.
    get_response = mock_gcs_dependencies["async_client"].get.return_value
    get_response.status_code = 200
    get_response.text = '{"foo": "bar"}'

    assert await gcs_cache.async_get_cache("key") == {"foo": "bar"}

View File

@@ -0,0 +1,90 @@
import asyncio
import json
import os
import sys
import time
from unittest.mock import MagicMock, patch
import httpx
import pytest
import respx
from fastapi.testclient import TestClient
sys.path.insert(
0, os.path.abspath("../../..")
) # Adds the parent directory to the system path
from unittest.mock import AsyncMock
from litellm.caching.in_memory_cache import InMemoryCache
def test_in_memory_openai_obj_cache():
    """Arbitrary Python objects (here an OpenAI client) survive a set/get round trip."""
    from openai import OpenAI

    client_obj = OpenAI(api_key="my-fake-key")
    memory_cache = InMemoryCache()

    memory_cache.set_cache(key="my-fake-key", value=client_obj)
    round_tripped = memory_cache.get_cache(key="my-fake-key")

    assert round_tripped is not None
    assert round_tripped == client_obj
def test_in_memory_cache_max_size_per_item():
    """check_value_size must reject items larger than max_size_per_item."""
    memory_cache = InMemoryCache(max_size_per_item=100)
    oversized = "a" * 100000000  # far beyond the configured limit
    assert memory_cache.check_value_size(oversized) is False
def test_in_memory_cache_ttl():
    """
    Check that
    - re-setting a live key keeps its original expiry time
    - once an entry expires, a read also removes its ttl bookkeeping
    """
    memory_cache = InMemoryCache()

    memory_cache.set_cache(key="my-fake-key", value="my-fake-value", ttl=10)
    first_expiry = memory_cache.ttl_dict["my-fake-key"]
    assert first_expiry is not None

    # Overwriting a still-live key must not push its expiry further out.
    memory_cache.set_cache(key="my-fake-key", value="my-fake-value-2", ttl=10)
    assert memory_cache.ttl_dict["my-fake-key"] == first_expiry

    ## On object expiration, the ttl should be removed
    memory_cache.set_cache(key="new-fake-key", value="new-fake-value", ttl=1)
    assert memory_cache.ttl_dict["new-fake-key"] is not None
    time.sleep(1)
    memory_cache.get_cache(key="new-fake-key")
    assert memory_cache.ttl_dict.get("new-fake-key") is None
def test_in_memory_cache_ttl_allow_override():
    """
    Check that re-setting a key AFTER its entry has expired (with no
    intervening get) is allowed to install a fresh ttl.
    """
    memory_cache = InMemoryCache()

    memory_cache.set_cache(key="new-fake-key", value="new-fake-value", ttl=1)
    original_expiry = memory_cache.ttl_dict["new-fake-key"]
    assert original_expiry is not None

    time.sleep(1)  # let the first entry lapse

    memory_cache.set_cache(key="new-fake-key", value="new-fake-value-2", ttl=1)
    refreshed_expiry = memory_cache.ttl_dict["new-fake-key"]
    assert refreshed_expiry is not None
    assert refreshed_expiry != original_expiry

View File

@@ -0,0 +1,411 @@
import os
import sys
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
sys.path.insert(
0, os.path.abspath("../../..")
) # Adds the parent directory to the system path
def test_qdrant_semantic_cache_initialization(monkeypatch):
    """
    Test QDRANT semantic cache initialization with proper parameters.
    Verifies that the cache is initialized correctly with given configuration.
    """
    # Mock the httpx clients and API calls
    with patch("litellm.llms.custom_httpx.http_handler._get_httpx_client") as mock_sync_client, \
            patch("litellm.llms.custom_httpx.http_handler.get_async_httpx_client") as mock_async_client:
        # Mock the collection exists check
        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_response.json.return_value = {"result": {"exists": True}}
        mock_sync_client_instance = MagicMock()
        mock_sync_client_instance.get.return_value = mock_response
        mock_sync_client.return_value = mock_sync_client_instance
        # Import inside the patch context so the constructor's HTTP traffic
        # goes through the mocked clients.
        from litellm.caching.qdrant_semantic_cache import QdrantSemanticCache
        # Initialize the cache with similarity threshold
        qdrant_cache = QdrantSemanticCache(
            collection_name="test_collection",
            qdrant_api_base="http://test.qdrant.local",
            qdrant_api_key="test_key",
            similarity_threshold=0.8,
        )
        # Verify the cache was initialized with correct parameters
        assert qdrant_cache.collection_name == "test_collection"
        assert qdrant_cache.qdrant_api_base == "http://test.qdrant.local"
        assert qdrant_cache.qdrant_api_key == "test_key"
        assert qdrant_cache.similarity_threshold == 0.8
        # Test initialization with missing similarity_threshold
        with pytest.raises(Exception, match="similarity_threshold must be provided"):
            QdrantSemanticCache(
                collection_name="test_collection",
                qdrant_api_base="http://test.qdrant.local",
                qdrant_api_key="test_key",
            )
def test_qdrant_semantic_cache_get_cache_hit():
    """
    Test QDRANT semantic cache get method when there's a cache hit.
    Verifies that cached results are properly retrieved and parsed.
    """
    with patch("litellm.llms.custom_httpx.http_handler._get_httpx_client") as mock_sync_client, \
            patch("litellm.llms.custom_httpx.http_handler.get_async_httpx_client") as mock_async_client:
        # Mock the collection exists check
        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_response.json.return_value = {"result": {"exists": True}}
        mock_sync_client_instance = MagicMock()
        mock_sync_client_instance.get.return_value = mock_response
        mock_sync_client.return_value = mock_sync_client_instance
        # Import inside the patch context so construction uses the mocked clients.
        from litellm.caching.qdrant_semantic_cache import QdrantSemanticCache
        # Initialize cache
        qdrant_cache = QdrantSemanticCache(
            collection_name="test_collection",
            qdrant_api_base="http://test.qdrant.local",
            qdrant_api_key="test_key",
            similarity_threshold=0.8,
        )
        # Mock a cache hit result from search API
        mock_search_response = MagicMock()
        mock_search_response.status_code = 200
        mock_search_response.json.return_value = {
            "result": [
                {
                    "payload": {
                        "text": "What is the capital of France?",  # Original prompt
                        "response": '{"id": "test-123", "choices": [{"message": {"content": "Paris is the capital of France."}}]}'
                    },
                    "score": 0.9  # above the 0.8 similarity threshold -> treated as a hit
                }
            ]
        }
        qdrant_cache.sync_client.post = MagicMock(return_value=mock_search_response)
        # Mock the embedding function
        with patch(
            "litellm.embedding",
            return_value={"data": [{"embedding": [0.1, 0.2, 0.3]}]}
        ):
            # Test get_cache with a message
            result = qdrant_cache.get_cache(
                key="test_key",
                messages=[{"content": "What is the capital of France?"}]
            )
            # Verify result is properly parsed (JSON string payload -> dict)
            expected_result = {
                "id": "test-123",
                "choices": [{"message": {"content": "Paris is the capital of France."}}]
            }
            assert result == expected_result
            # Verify search was called
            qdrant_cache.sync_client.post.assert_called()
def test_qdrant_semantic_cache_get_cache_miss():
    """
    Test QDRANT semantic cache get method when there's a cache miss.
    Verifies that None is returned when no similar cached results are found.
    """
    with patch("litellm.llms.custom_httpx.http_handler._get_httpx_client") as mock_sync_client, \
            patch("litellm.llms.custom_httpx.http_handler.get_async_httpx_client") as mock_async_client:
        # Mock the collection exists check
        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_response.json.return_value = {"result": {"exists": True}}
        mock_sync_client_instance = MagicMock()
        mock_sync_client_instance.get.return_value = mock_response
        mock_sync_client.return_value = mock_sync_client_instance
        # Import inside the patch context so construction uses the mocked clients.
        from litellm.caching.qdrant_semantic_cache import QdrantSemanticCache
        # Initialize cache
        qdrant_cache = QdrantSemanticCache(
            collection_name="test_collection",
            qdrant_api_base="http://test.qdrant.local",
            qdrant_api_key="test_key",
            similarity_threshold=0.8,
        )
        # Mock a cache miss (no results)
        mock_search_response = MagicMock()
        mock_search_response.status_code = 200
        mock_search_response.json.return_value = {"result": []}
        qdrant_cache.sync_client.post = MagicMock(return_value=mock_search_response)
        # Mock the embedding function
        with patch(
            "litellm.embedding",
            return_value={"data": [{"embedding": [0.1, 0.2, 0.3]}]}
        ):
            # Test get_cache with a message
            result = qdrant_cache.get_cache(
                key="test_key",
                messages=[{"content": "What is the capital of Spain?"}]
            )
            # Verify None is returned for cache miss
            assert result is None
            # Verify search was called
            qdrant_cache.sync_client.post.assert_called()
@pytest.mark.asyncio
async def test_qdrant_semantic_cache_async_get_cache_hit():
    """
    Test QDRANT semantic cache async get method when there's a cache hit.
    Verifies that cached results are properly retrieved and parsed asynchronously.
    """
    with patch("litellm.llms.custom_httpx.http_handler._get_httpx_client") as mock_sync_client, \
            patch("litellm.llms.custom_httpx.http_handler.get_async_httpx_client") as mock_async_client:
        # Mock the collection exists check
        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_response.json.return_value = {"result": {"exists": True}}
        mock_sync_client_instance = MagicMock()
        mock_sync_client_instance.get.return_value = mock_response
        mock_sync_client.return_value = mock_sync_client_instance
        # Mock async client
        mock_async_client_instance = AsyncMock()
        mock_async_client.return_value = mock_async_client_instance
        # Import inside the patch context so construction uses the mocked clients.
        from litellm.caching.qdrant_semantic_cache import QdrantSemanticCache
        # Initialize cache
        qdrant_cache = QdrantSemanticCache(
            collection_name="test_collection",
            qdrant_api_base="http://test.qdrant.local",
            qdrant_api_key="test_key",
            similarity_threshold=0.8,
        )
        # Mock a cache hit result from async search API
        # Note: .json() should be sync even for async responses
        mock_search_response = MagicMock()
        mock_search_response.status_code = 200
        mock_search_response.json.return_value = {
            "result": [
                {
                    "payload": {
                        "text": "What is the capital of Spain?",  # Original prompt
                        "response": '{"id": "test-456", "choices": [{"message": {"content": "Madrid is the capital of Spain."}}]}'
                    },
                    "score": 0.85  # above the 0.8 similarity threshold -> hit
                }
            ]
        }
        qdrant_cache.async_client.post = AsyncMock(return_value=mock_search_response)
        # Mock the async embedding function
        with patch(
            "litellm.aembedding",
            return_value={"data": [{"embedding": [0.4, 0.5, 0.6]}]}
        ):
            # Test async_get_cache with a message
            result = await qdrant_cache.async_get_cache(
                key="test_key",
                messages=[{"content": "What is the capital of Spain?"}],
                metadata={},
            )
            # Verify result is properly parsed (JSON string payload -> dict)
            expected_result = {
                "id": "test-456",
                "choices": [{"message": {"content": "Madrid is the capital of Spain."}}]
            }
            assert result == expected_result
            # Verify async search was called
            qdrant_cache.async_client.post.assert_called()
@pytest.mark.asyncio
async def test_qdrant_semantic_cache_async_get_cache_miss():
    """
    Test QDRANT semantic cache async get method when there's a cache miss.
    Verifies that None is returned when no similar cached results are found.
    """
    with patch("litellm.llms.custom_httpx.http_handler._get_httpx_client") as mock_sync_client, \
            patch("litellm.llms.custom_httpx.http_handler.get_async_httpx_client") as mock_async_client:
        # Mock the collection exists check
        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_response.json.return_value = {"result": {"exists": True}}
        mock_sync_client_instance = MagicMock()
        mock_sync_client_instance.get.return_value = mock_response
        mock_sync_client.return_value = mock_sync_client_instance
        # Mock async client
        mock_async_client_instance = AsyncMock()
        mock_async_client.return_value = mock_async_client_instance
        # Import inside the patch context so construction uses the mocked clients.
        from litellm.caching.qdrant_semantic_cache import QdrantSemanticCache
        # Initialize cache
        qdrant_cache = QdrantSemanticCache(
            collection_name="test_collection",
            qdrant_api_base="http://test.qdrant.local",
            qdrant_api_key="test_key",
            similarity_threshold=0.8,
        )
        # Mock a cache miss (no results)
        mock_search_response = MagicMock()  # Note: .json() should be sync
        mock_search_response.status_code = 200
        mock_search_response.json.return_value = {"result": []}
        qdrant_cache.async_client.post = AsyncMock(return_value=mock_search_response)
        # Mock the async embedding function
        with patch(
            "litellm.aembedding",
            return_value={"data": [{"embedding": [0.7, 0.8, 0.9]}]}
        ):
            # Test async_get_cache with a message
            result = await qdrant_cache.async_get_cache(
                key="test_key",
                messages=[{"content": "What is the capital of Italy?"}],
                metadata={},
            )
            # Verify None is returned for cache miss
            assert result is None
            # Verify async search was called
            qdrant_cache.async_client.post.assert_called()
def test_qdrant_semantic_cache_set_cache():
    """
    Test QDRANT semantic cache set method.
    Verifies that responses are properly stored in the cache.
    """
    with patch("litellm.llms.custom_httpx.http_handler._get_httpx_client") as mock_sync_client, \
            patch("litellm.llms.custom_httpx.http_handler.get_async_httpx_client") as mock_async_client:
        # Mock the collection exists check
        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_response.json.return_value = {"result": {"exists": True}}
        mock_sync_client_instance = MagicMock()
        mock_sync_client_instance.get.return_value = mock_response
        mock_sync_client.return_value = mock_sync_client_instance
        # Import inside the patch context so construction uses the mocked clients.
        from litellm.caching.qdrant_semantic_cache import QdrantSemanticCache
        # Initialize cache
        qdrant_cache = QdrantSemanticCache(
            collection_name="test_collection",
            qdrant_api_base="http://test.qdrant.local",
            qdrant_api_key="test_key",
            similarity_threshold=0.8,
        )
        # Mock the upsert method (points are written via HTTP PUT)
        mock_upsert_response = MagicMock()
        mock_upsert_response.status_code = 200
        qdrant_cache.sync_client.put = MagicMock(return_value=mock_upsert_response)
        # Mock response to cache
        response_to_cache = {
            "id": "test-789",
            "choices": [{"message": {"content": "Rome is the capital of Italy."}}]
        }
        # Mock the embedding function
        with patch(
            "litellm.embedding",
            return_value={"data": [{"embedding": [0.1, 0.1, 0.1]}]}
        ):
            # Test set_cache
            qdrant_cache.set_cache(
                key="test_key",
                value=response_to_cache,
                messages=[{"content": "What is the capital of Italy?"}]
            )
            # Verify upsert was called
            qdrant_cache.sync_client.put.assert_called()
@pytest.mark.asyncio
async def test_qdrant_semantic_cache_async_set_cache():
    """
    Test QDRANT semantic cache async set method.
    Verifies that responses are properly stored in the cache asynchronously.
    """
    with patch("litellm.llms.custom_httpx.http_handler._get_httpx_client") as mock_sync_client, \
            patch("litellm.llms.custom_httpx.http_handler.get_async_httpx_client") as mock_async_client:
        # Mock the collection exists check
        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_response.json.return_value = {"result": {"exists": True}}
        mock_sync_client_instance = MagicMock()
        mock_sync_client_instance.get.return_value = mock_response
        mock_sync_client.return_value = mock_sync_client_instance
        # Mock async client
        mock_async_client_instance = AsyncMock()
        mock_async_client.return_value = mock_async_client_instance
        # Import inside the patch context so construction uses the mocked clients.
        from litellm.caching.qdrant_semantic_cache import QdrantSemanticCache
        # Initialize cache
        qdrant_cache = QdrantSemanticCache(
            collection_name="test_collection",
            qdrant_api_base="http://test.qdrant.local",
            qdrant_api_key="test_key",
            similarity_threshold=0.8,
        )
        # Mock the async upsert method
        mock_upsert_response = MagicMock()  # Note: .json() should be sync
        mock_upsert_response.status_code = 200
        qdrant_cache.async_client.put = AsyncMock(return_value=mock_upsert_response)
        # Mock response to cache
        response_to_cache = {
            "id": "test-999",
            "choices": [{"message": {"content": "Berlin is the capital of Germany."}}]
        }
        # Mock the async embedding function
        with patch(
            "litellm.aembedding",
            return_value={"data": [{"embedding": [0.2, 0.2, 0.2]}]}
        ):
            # Test async_set_cache
            await qdrant_cache.async_set_cache(
                key="test_key",
                value=response_to_cache,
                messages=[{"content": "What is the capital of Germany?"}],
                metadata={}
            )
            # Verify async upsert was called
            qdrant_cache.async_client.put.assert_called()

View File

@@ -0,0 +1,122 @@
import os
import sys
from unittest.mock import MagicMock, patch
import pytest
from fastapi.testclient import TestClient
sys.path.insert(
0, os.path.abspath("../../..")
) # Adds the parent directory to the system path
from unittest.mock import AsyncMock
from litellm.caching.redis_cache import RedisCache
@pytest.fixture
def redis_no_ping():
    """Force RedisCache.__init__ down the "no running event loop" path so it
    does not schedule a background async ping task during tests."""
    with patch("asyncio.get_running_loop") as loop_patch:
        loop_patch.side_effect = RuntimeError("No running event loop")
        yield
@pytest.mark.parametrize("namespace", [None, "test"])
@pytest.mark.asyncio
async def test_redis_cache_async_increment(namespace, monkeypatch, redis_no_ping):
    """async_increment should issue INCRBYFLOAT for the (namespaced) key."""
    monkeypatch.setenv("REDIS_HOST", "https://my-test-host")
    cache_under_test = RedisCache(namespace=namespace)
    assert cache_under_test is not None

    # Async client stub that also works as an async context manager.
    fake_redis = AsyncMock()
    fake_redis.__aenter__.return_value = fake_redis
    fake_redis.__aexit__.return_value = None

    namespaced_key = "test:test" if namespace else "test"
    with patch.object(
        cache_under_test, "init_async_client", return_value=fake_redis
    ):
        await cache_under_test.async_increment(key=namespaced_key, value=1)
        fake_redis.incrbyfloat.assert_called_once_with(
            name=namespaced_key, amount=1
        )
@pytest.mark.asyncio
async def test_redis_client_init_with_socket_timeout(monkeypatch, redis_no_ping):
    """socket_timeout given to RedisCache must propagate to the async client."""
    monkeypatch.setenv("REDIS_HOST", "my-fake-host")
    timed_cache = RedisCache(socket_timeout=1.0)

    assert timed_cache.redis_kwargs["socket_timeout"] == 1.0

    async_client = timed_cache.init_async_client()
    assert async_client is not None
    assert async_client.connection_pool.connection_kwargs["socket_timeout"] == 1.0
@pytest.mark.asyncio
async def test_redis_cache_async_batch_get_cache(monkeypatch, redis_no_ping):
    """async_batch_get_cache should MGET all keys and JSON-decode every hit."""
    monkeypatch.setenv("REDIS_HOST", "https://my-test-host")
    batch_cache = RedisCache()

    fake_redis = AsyncMock()
    fake_redis.__aenter__.return_value = fake_redis
    fake_redis.__aexit__.return_value = None
    # key2 is a miss; the other two come back as raw JSON bytes.
    fake_redis.mget.return_value = [
        b'{"key1": "value1"}',
        None,
        b'{"key3": "value3"}',
    ]

    with patch.object(batch_cache, "init_async_client", return_value=fake_redis):
        decoded = await batch_cache.async_batch_get_cache(
            key_list=["key1", "key2", "key3"]
        )

    fake_redis.mget.assert_called_once()
    assert decoded["key1"] == {"key1": "value1"}
    assert decoded["key2"] is None
    assert decoded["key3"] == {"key3": "value3"}
@pytest.mark.asyncio
async def test_handle_lpop_count_for_older_redis_versions(monkeypatch):
    """The pre-Redis-7 LPOP-with-count fallback should pop one element per
    pipeline execution and concatenate the results in order."""
    monkeypatch.setenv("REDIS_HOST", "https://my-test-host")
    legacy_cache = RedisCache()

    pipeline_stub = AsyncMock()
    # One pipeline execution per popped element.
    pipeline_stub.execute.side_effect = [[b"value1"], [b"value2"]]

    popped = await legacy_cache.handle_lpop_count_for_older_redis_versions(
        pipe=pipeline_stub, key="test_key", count=2
    )

    assert popped == [b"value1", b"value2"]
    assert pipeline_stub.lpop.call_count == 2
    assert pipeline_stub.execute.call_count == 2

View File

@@ -0,0 +1,66 @@
import json
import os
import sys
from unittest.mock import MagicMock, patch
import pytest
from fastapi.testclient import TestClient
sys.path.insert(
0, os.path.abspath("../../..")
) # Adds the parent directory to the system path
from litellm.caching.redis_cluster_cache import RedisClusterCache
@patch("litellm._redis.init_redis_cluster")
def test_redis_cluster_batch_get(mock_init_redis_cluster):
    """
    Cluster batch reads must go through mget_nonatomic (cluster-safe),
    never plain mget.
    """
    cluster_client = MagicMock()
    cluster_client.mget_nonatomic.return_value = [None, None]  # no cache hits
    mock_init_redis_cluster.return_value = cluster_client

    cluster_cache = RedisClusterCache(
        startup_nodes=[{"host": "localhost", "port": 6379}],
        password="hello",
    )

    cluster_cache.batch_get_cache(["key1", "key2"])

    cluster_client.mget_nonatomic.assert_called_once()
    assert not cluster_client.mget.called
@pytest.mark.asyncio
@patch("litellm._redis.init_redis_cluster")
async def test_redis_cluster_async_batch_get(mock_init_redis_cluster):
    """
    Test that RedisClusterCache uses mget_nonatomic instead of mget for async
    batch operations.
    """
    # Mock cluster client: two lookups, both misses.
    mock_redis = MagicMock()
    mock_redis.mget_nonatomic.return_value = [None, None]
    # Wire the patched constructor path to the same mock (mirrors the sync
    # test above; previously the patched init_redis_cluster returned a
    # default MagicMock unrelated to the client asserted on below).
    mock_init_redis_cluster.return_value = mock_redis

    # Create RedisClusterCache instance with mock client
    cache = RedisClusterCache(
        startup_nodes=[{"host": "localhost", "port": 6379}],
        password="hello",
    )
    # Route the async client through the same mock as well.
    cache.init_async_client = MagicMock(return_value=mock_redis)

    # Test async_batch_get_cache
    await cache.async_batch_get_cache(["key1", "key2"])

    # Verify mget_nonatomic was called instead of mget
    mock_redis.mget_nonatomic.assert_called_once()
    assert not mock_redis.mget.called

View File

@@ -0,0 +1,146 @@
import os
import sys
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
sys.path.insert(
0, os.path.abspath("../../..")
) # Adds the parent directory to the system path
# Tests for RedisSemanticCache
def test_redis_semantic_cache_initialization(monkeypatch):
    """Initialization should record the similarity threshold, derive the
    distance threshold as 1 - similarity, and default the embedding model;
    a missing similarity_threshold must raise ValueError."""
    # Mock the redisvl import (third-party dep not required for this test)
    semantic_cache_mock = MagicMock()
    with patch.dict(
        "sys.modules",
        {
            "redisvl.extensions.llmcache": MagicMock(SemanticCache=semantic_cache_mock),
            "redisvl.utils.vectorize": MagicMock(CustomTextVectorizer=MagicMock()),
        },
    ):
        # Import inside the patch so litellm picks up the fake redisvl modules.
        from litellm.caching.redis_semantic_cache import RedisSemanticCache
        # Set environment variables
        monkeypatch.setenv("REDIS_HOST", "localhost")
        monkeypatch.setenv("REDIS_PORT", "6379")
        monkeypatch.setenv("REDIS_PASSWORD", "test_password")
        # Initialize the cache with a similarity threshold
        redis_semantic_cache = RedisSemanticCache(similarity_threshold=0.8)
        # Verify the semantic cache was initialized with correct parameters
        assert redis_semantic_cache.similarity_threshold == 0.8
        # Use pytest.approx for floating point comparison to handle precision issues
        assert redis_semantic_cache.distance_threshold == pytest.approx(0.2, abs=1e-10)
        assert redis_semantic_cache.embedding_model == "text-embedding-ada-002"
        # Test initialization with missing similarity_threshold
        with pytest.raises(ValueError, match="similarity_threshold must be provided"):
            RedisSemanticCache()
def test_redis_semantic_cache_get_cache(monkeypatch):
    """get_cache should embed the prompt, query the semantic index via
    llmcache.check, and return the JSON-decoded cached response on a match."""
    # Mock the redisvl import and embedding function
    semantic_cache_mock = MagicMock()
    custom_vectorizer_mock = MagicMock()
    with patch.dict(
        "sys.modules",
        {
            "redisvl.extensions.llmcache": MagicMock(SemanticCache=semantic_cache_mock),
            "redisvl.utils.vectorize": MagicMock(
                CustomTextVectorizer=custom_vectorizer_mock
            ),
        },
    ):
        # Import inside the patch so litellm picks up the fake redisvl modules.
        from litellm.caching.redis_semantic_cache import RedisSemanticCache
        # Set environment variables
        monkeypatch.setenv("REDIS_HOST", "localhost")
        monkeypatch.setenv("REDIS_PORT", "6379")
        monkeypatch.setenv("REDIS_PASSWORD", "test_password")
        # Initialize cache
        redis_semantic_cache = RedisSemanticCache(similarity_threshold=0.8)
        # Mock the llmcache.check method to return a result
        mock_result = [
            {
                "prompt": "What is the capital of France?",
                "response": '{"content": "Paris is the capital of France."}',
                "vector_distance": 0.1,  # Distance of 0.1 means similarity of 0.9
            }
        ]
        redis_semantic_cache.llmcache.check = MagicMock(return_value=mock_result)
        # Mock the embedding function
        with patch(
            "litellm.embedding", return_value={"data": [{"embedding": [0.1, 0.2, 0.3]}]}
        ):
            # Test get_cache with a message
            result = redis_semantic_cache.get_cache(
                key="test_key", messages=[{"content": "What is the capital of France?"}]
            )
            # Verify result is properly parsed (JSON string -> dict)
            assert result == {"content": "Paris is the capital of France."}
            # Verify llmcache.check was called
            redis_semantic_cache.llmcache.check.assert_called_once()
@pytest.mark.asyncio
async def test_redis_semantic_cache_async_get_cache(monkeypatch):
    """async_get_cache should embed the prompt asynchronously, query the
    semantic index via llmcache.acheck, and return the decoded response."""
    # Mock the redisvl import
    semantic_cache_mock = MagicMock()
    custom_vectorizer_mock = MagicMock()
    with patch.dict(
        "sys.modules",
        {
            "redisvl.extensions.llmcache": MagicMock(SemanticCache=semantic_cache_mock),
            "redisvl.utils.vectorize": MagicMock(
                CustomTextVectorizer=custom_vectorizer_mock
            ),
        },
    ):
        # Import inside the patch so litellm picks up the fake redisvl modules.
        from litellm.caching.redis_semantic_cache import RedisSemanticCache
        # Set environment variables
        monkeypatch.setenv("REDIS_HOST", "localhost")
        monkeypatch.setenv("REDIS_PORT", "6379")
        monkeypatch.setenv("REDIS_PASSWORD", "test_password")
        # Initialize cache
        redis_semantic_cache = RedisSemanticCache(similarity_threshold=0.8)
        # Mock the async methods
        mock_result = [
            {
                "prompt": "What is the capital of France?",
                "response": '{"content": "Paris is the capital of France."}',
                "vector_distance": 0.1,  # Distance of 0.1 means similarity of 0.9
            }
        ]
        redis_semantic_cache.llmcache.acheck = AsyncMock(return_value=mock_result)
        redis_semantic_cache._get_async_embedding = AsyncMock(
            return_value=[0.1, 0.2, 0.3]
        )
        # Test async_get_cache with a message
        result = await redis_semantic_cache.async_get_cache(
            key="test_key",
            messages=[{"content": "What is the capital of France?"}],
            metadata={},
        )
        # Verify result is properly parsed (JSON string -> dict)
        assert result == {"content": "Paris is the capital of France."}
        # Verify methods were called
        redis_semantic_cache._get_async_embedding.assert_called_once()
        redis_semantic_cache.llmcache.acheck.assert_called_once()