Added LiteLLM to the stack
@@ -0,0 +1,198 @@
import os
import sys
from unittest.mock import MagicMock, patch, AsyncMock

import pytest
from fastapi.testclient import TestClient

sys.path.insert(
    0, os.path.abspath("../../..")
)  # Adds the parent directory to the system path

from litellm.caching.azure_blob_cache import AzureBlobCache


@pytest.fixture
def mock_azure_dependencies():
    """Mock all Azure dependencies to avoid requiring actual Azure credentials"""

    # Create mock container clients that will be assigned to the cache instance
    mock_container_client = MagicMock()
    mock_async_container_client = AsyncMock()

    # Mock credentials
    mock_credential = MagicMock()
    mock_async_credential = AsyncMock()

    # Create mock blob service clients that return the container clients
    mock_blob_service_client = MagicMock()
    mock_blob_service_client.get_container_client.return_value = mock_container_client

    mock_async_blob_service_client = AsyncMock()
    # For AsyncMock, we need to make get_container_client return the mock directly, not a coroutine
    mock_async_blob_service_client.get_container_client = MagicMock(return_value=mock_async_container_client)

    # Patch Azure dependencies at their source locations
    with patch("azure.identity.DefaultAzureCredential", return_value=mock_credential), \
            patch("azure.identity.aio.DefaultAzureCredential", return_value=mock_async_credential), \
            patch("azure.storage.blob.BlobServiceClient", return_value=mock_blob_service_client), \
            patch("azure.storage.blob.aio.BlobServiceClient", return_value=mock_async_blob_service_client), \
            patch("azure.core.exceptions.ResourceExistsError"):

        yield {
            "container_client": mock_container_client,
            "async_container_client": mock_async_container_client,
            "blob_service_client": mock_blob_service_client,
            "async_blob_service_client": mock_async_blob_service_client,
            "credential": mock_credential,
            "async_credential": mock_async_credential,
        }


@pytest.mark.asyncio
async def test_blob_cache_async_get_cache(mock_azure_dependencies):
    """Test async_get_cache method with mocked Azure dependencies"""

    # Create cache instance (this will use the mocked dependencies)
    cache = AzureBlobCache("https://my-test-host", "test-container")

    # Mock the download_blob response
    mock_blob = AsyncMock()
    mock_blob.readall.return_value = b'{"test_key": "test_value"}'

    # Set up the mock for download_blob on the actual container client instance
    cache.async_container_client.download_blob.return_value = mock_blob

    # Test successful cache retrieval
    result = await cache.async_get_cache("test_key")

    # Verify the call was made correctly
    cache.async_container_client.download_blob.assert_called_once_with("test_key")
    mock_blob.readall.assert_called_once()

    # Check the result
    assert result == {"test_key": "test_value"}


@pytest.mark.asyncio
async def test_blob_cache_async_get_cache_not_found(mock_azure_dependencies):
    """Test async_get_cache method when blob is not found"""

    # Import the exception inside the test to avoid import issues
    from azure.core.exceptions import ResourceNotFoundError

    cache = AzureBlobCache("https://my-test-host", "test-container")

    # Mock ResourceNotFoundError
    cache.async_container_client.download_blob.side_effect = ResourceNotFoundError("Blob not found")

    # Test cache miss
    result = await cache.async_get_cache("nonexistent_key")

    # Verify the call was made and result is None
    cache.async_container_client.download_blob.assert_called_once_with("nonexistent_key")
    assert result is None


@pytest.mark.asyncio
async def test_blob_cache_async_set_cache(mock_azure_dependencies):
    """Test async_set_cache method with mocked Azure dependencies"""

    cache = AzureBlobCache("https://my-test-host", "test-container")

    test_value = {"key": "value", "number": 42}

    # Test setting cache
    await cache.async_set_cache("test_key", test_value)

    # Verify the call was made correctly
    cache.async_container_client.upload_blob.assert_called_once_with(
        "test_key",
        '{"key": "value", "number": 42}',
        overwrite=True
    )


def test_blob_cache_sync_get_cache(mock_azure_dependencies):
    """Test sync get_cache method with mocked Azure dependencies"""

    cache = AzureBlobCache("https://my-test-host", "test-container")

    # Mock the download_blob response
    mock_blob = MagicMock()
    mock_blob.readall.return_value = b'{"sync_key": "sync_value"}'

    cache.container_client.download_blob.return_value = mock_blob

    # Test successful cache retrieval
    result = cache.get_cache("sync_key")

    # Verify the call was made correctly
    cache.container_client.download_blob.assert_called_once_with("sync_key")
    mock_blob.readall.assert_called_once()

    # Check the result
    assert result == {"sync_key": "sync_value"}


def test_blob_cache_sync_set_cache(mock_azure_dependencies):
    """Test sync set_cache method with mocked Azure dependencies"""

    cache = AzureBlobCache("https://my-test-host", "test-container")

    test_value = {"sync_key": "sync_value", "number": 123}

    # Test setting cache
    cache.set_cache("sync_test_key", test_value)

    # Verify the call was made correctly
    cache.container_client.upload_blob.assert_called_once_with(
        "sync_test_key",
        '{"sync_key": "sync_value", "number": 123}'
    )


def test_blob_cache_sync_get_cache_not_found(mock_azure_dependencies):
    """Test sync get_cache method when blob is not found"""

    from azure.core.exceptions import ResourceNotFoundError

    cache = AzureBlobCache("https://my-test-host", "test-container")

    # Mock ResourceNotFoundError
    cache.container_client.download_blob.side_effect = ResourceNotFoundError("Blob not found")

    # Test cache miss
    result = cache.get_cache("nonexistent_key")

    # Verify the call was made and result is None
    cache.container_client.download_blob.assert_called_once_with("nonexistent_key")
    assert result is None


@pytest.mark.asyncio
async def test_blob_cache_async_set_cache_pipeline(mock_azure_dependencies):
    """Test async_set_cache_pipeline method with mocked Azure dependencies"""

    cache = AzureBlobCache("https://my-test-host", "test-container")

    # Test data for pipeline
    cache_list = [
        ("key1", {"value": "data1"}),
        ("key2", {"value": "data2"}),
        ("key3", {"value": "data3"}),
    ]

    # Test pipeline cache setting
    await cache.async_set_cache_pipeline(cache_list)

    # Verify all calls were made correctly
    expected_calls = [
        (("key1", '{"value": "data1"}'), {"overwrite": True}),
        (("key2", '{"value": "data2"}'), {"overwrite": True}),
        (("key3", '{"value": "data3"}'), {"overwrite": True}),
    ]

    assert cache.async_container_client.upload_blob.call_count == 3
    for expected_call in expected_calls:
        cache.async_container_client.upload_blob.assert_any_call(*expected_call[0], **expected_call[1])
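A note on the get_container_client override in the fixture above: attributes of an AsyncMock are themselves AsyncMocks, so calling one returns a coroutine, while the blob-service code calls get_container_client synchronously. A minimal, standard-library-only sketch of the difference (the names here are illustrative, not taken from the diff):

import asyncio
from unittest.mock import AsyncMock, MagicMock


async def demo():
    svc = AsyncMock()
    # Calling a default AsyncMock attribute yields a coroutine; the value
    # only appears after awaiting it.
    container = await svc.get_container_client("c")
    print(type(container))  # <class 'unittest.mock.MagicMock'>

    # Replacing the attribute with a plain MagicMock makes the call
    # synchronous, matching code that uses the client without await.
    svc.get_container_client = MagicMock(return_value="container")
    print(svc.get_container_client("c"))  # container


asyncio.run(demo())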
@@ -0,0 +1,54 @@
import asyncio
import json
import os
import sys
import time
from unittest.mock import MagicMock, patch

import httpx
import pytest
import respx
from fastapi.testclient import TestClient

sys.path.insert(
    0, os.path.abspath("../../..")
)  # Adds the parent directory to the system path
from datetime import datetime
from unittest.mock import AsyncMock, MagicMock

from litellm.caching.caching_handler import LLMCachingHandler


@pytest.mark.asyncio
async def test_process_async_embedding_cached_response():
    llm_caching_handler = LLMCachingHandler(
        original_function=MagicMock(),
        request_kwargs={},
        start_time=datetime.now(),
    )

    args = {
        "cached_result": [
            {
                "embedding": [-0.025122925639152527, -0.019487135112285614],
                "index": 0,
                "object": "embedding",
            }
        ]
    }

    mock_logging_obj = MagicMock()
    mock_logging_obj.async_success_handler = AsyncMock()
    response, cache_hit = llm_caching_handler._process_async_embedding_cached_response(
        final_embedding_cached_response=None,
        cached_result=args["cached_result"],
        kwargs={"model": "text-embedding-ada-002", "input": "test"},
        logging_obj=mock_logging_obj,
        start_time=datetime.now(),
        model="text-embedding-ada-002",
    )

    assert cache_hit

    print(f"response: {response}")
    assert len(response.data) == 1
@@ -0,0 +1,36 @@
import os
import sys
from unittest.mock import MagicMock, AsyncMock, patch

import pytest

sys.path.insert(0, os.path.abspath("../../.."))

from litellm.caching.gcs_cache import GCSCache


@pytest.fixture
def mock_gcs_dependencies():
    """Mock httpx clients and GCS auth"""
    mock_sync_client = MagicMock()
    mock_async_client = AsyncMock()

    with patch("litellm.caching.gcs_cache._get_httpx_client", return_value=mock_sync_client), \
            patch("litellm.caching.gcs_cache.get_async_httpx_client", return_value=mock_async_client), \
            patch("litellm.caching.gcs_cache.GCSBucketBase.sync_construct_request_headers", return_value={}):
        yield {
            "sync_client": mock_sync_client,
            "async_client": mock_async_client,
        }


@pytest.mark.asyncio
async def test_gcs_cache_async_set_and_get(mock_gcs_dependencies):
    cache = GCSCache(bucket_name="test-bucket")
    await cache.async_set_cache("key", {"foo": "bar"})
    mock_gcs_dependencies["async_client"].post.assert_called_once()

    mock_gcs_dependencies["async_client"].get.return_value.status_code = 200
    mock_gcs_dependencies["async_client"].get.return_value.text = "{\"foo\": \"bar\"}"
    result = await cache.async_get_cache("key")
    assert result == {"foo": "bar"}
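The fixture above patches litellm.caching.gcs_cache._get_httpx_client rather than the helper's home module because unittest.mock.patch must replace a name in the namespace where it is looked up, not where it is defined. A hedged equivalent using patch.object on the already-imported module (assuming only the attribute already named in the patch target above):

from unittest.mock import MagicMock, patch

import litellm.caching.gcs_cache as gcs_cache_module

# Equivalent to patch("litellm.caching.gcs_cache._get_httpx_client", ...):
# the name is swapped out in the module that actually uses it.
with patch.object(gcs_cache_module, "_get_httpx_client", return_value=MagicMock()) as mocked_factory:
    gcs_cache_module._get_httpx_client()
    mocked_factory.assert_called_once()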
@@ -0,0 +1,90 @@
import asyncio
import json
import os
import sys
import time
from unittest.mock import MagicMock, patch

import httpx
import pytest
import respx
from fastapi.testclient import TestClient

sys.path.insert(
    0, os.path.abspath("../../..")
)  # Adds the parent directory to the system path
from unittest.mock import AsyncMock

from litellm.caching.in_memory_cache import InMemoryCache


def test_in_memory_openai_obj_cache():
    from openai import OpenAI

    openai_obj = OpenAI(api_key="my-fake-key")

    in_memory_cache = InMemoryCache()

    in_memory_cache.set_cache(key="my-fake-key", value=openai_obj)

    cached_obj = in_memory_cache.get_cache(key="my-fake-key")

    assert cached_obj is not None

    assert cached_obj == openai_obj


def test_in_memory_cache_max_size_per_item():
    """
    Test that the cache will not store items larger than the max size per item
    """
    in_memory_cache = InMemoryCache(max_size_per_item=100)

    result = in_memory_cache.check_value_size("a" * 100000000)

    assert result is False


def test_in_memory_cache_ttl():
    """
    Check that
    - if ttl is not set, it will be set to default ttl
    - if object expires, the ttl is also removed
    """
    in_memory_cache = InMemoryCache()

    in_memory_cache.set_cache(key="my-fake-key", value="my-fake-value", ttl=10)
    initial_ttl_time = in_memory_cache.ttl_dict["my-fake-key"]
    assert initial_ttl_time is not None

    in_memory_cache.set_cache(key="my-fake-key", value="my-fake-value-2", ttl=10)
    new_ttl_time = in_memory_cache.ttl_dict["my-fake-key"]
    assert new_ttl_time == initial_ttl_time  # ttl should not be updated

    ## On object expiration, the ttl should be removed
    in_memory_cache.set_cache(key="new-fake-key", value="new-fake-value", ttl=1)
    new_ttl_time = in_memory_cache.ttl_dict["new-fake-key"]
    assert new_ttl_time is not None
    time.sleep(1)
    cached_obj = in_memory_cache.get_cache(key="new-fake-key")
    new_ttl_time = in_memory_cache.ttl_dict.get("new-fake-key")
    assert new_ttl_time is None


def test_in_memory_cache_ttl_allow_override():
    """
    Check that
    - if the existing entry has expired, set_cache is allowed to override it
      and the ttl is refreshed
    """
    in_memory_cache = InMemoryCache()
    ## On object expiration, but no get_cache, the override should be allowed
    in_memory_cache.set_cache(key="new-fake-key", value="new-fake-value", ttl=1)
    initial_ttl_time = in_memory_cache.ttl_dict["new-fake-key"]
    assert initial_ttl_time is not None
    time.sleep(1)

    in_memory_cache.set_cache(key="new-fake-key", value="new-fake-value-2", ttl=1)
    new_ttl_time = in_memory_cache.ttl_dict["new-fake-key"]
    assert new_ttl_time is not None
    assert new_ttl_time != initial_ttl_time
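The TTL docstrings above also mention a default ttl when none is passed, which these tests never exercise. A hedged sketch of such a check, assuming set_cache without a ttl argument still records an expiry in ttl_dict (the same attribute the tests above read):

def test_in_memory_cache_default_ttl_sketch():
    in_memory_cache = InMemoryCache()
    in_memory_cache.set_cache(key="default-ttl-key", value="v")  # no explicit ttl
    # Assumption: the default ttl still records an expiry timestamp
    assert in_memory_cache.ttl_dict.get("default-ttl-key") is not None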
@@ -0,0 +1,411 @@
import os
import sys
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

sys.path.insert(
    0, os.path.abspath("../../..")
)  # Adds the parent directory to the system path


def test_qdrant_semantic_cache_initialization(monkeypatch):
    """
    Test QDRANT semantic cache initialization with proper parameters.
    Verifies that the cache is initialized correctly with given configuration.
    """
    # Mock the httpx clients and API calls
    with patch("litellm.llms.custom_httpx.http_handler._get_httpx_client") as mock_sync_client, \
            patch("litellm.llms.custom_httpx.http_handler.get_async_httpx_client") as mock_async_client:

        # Mock the collection exists check
        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_response.json.return_value = {"result": {"exists": True}}

        mock_sync_client_instance = MagicMock()
        mock_sync_client_instance.get.return_value = mock_response
        mock_sync_client.return_value = mock_sync_client_instance

        from litellm.caching.qdrant_semantic_cache import QdrantSemanticCache

        # Initialize the cache with similarity threshold
        qdrant_cache = QdrantSemanticCache(
            collection_name="test_collection",
            qdrant_api_base="http://test.qdrant.local",
            qdrant_api_key="test_key",
            similarity_threshold=0.8,
        )

        # Verify the cache was initialized with correct parameters
        assert qdrant_cache.collection_name == "test_collection"
        assert qdrant_cache.qdrant_api_base == "http://test.qdrant.local"
        assert qdrant_cache.qdrant_api_key == "test_key"
        assert qdrant_cache.similarity_threshold == 0.8

        # Test initialization with missing similarity_threshold
        with pytest.raises(Exception, match="similarity_threshold must be provided"):
            QdrantSemanticCache(
                collection_name="test_collection",
                qdrant_api_base="http://test.qdrant.local",
                qdrant_api_key="test_key",
            )


def test_qdrant_semantic_cache_get_cache_hit():
    """
    Test QDRANT semantic cache get method when there's a cache hit.
    Verifies that cached results are properly retrieved and parsed.
    """
    with patch("litellm.llms.custom_httpx.http_handler._get_httpx_client") as mock_sync_client, \
            patch("litellm.llms.custom_httpx.http_handler.get_async_httpx_client") as mock_async_client:

        # Mock the collection exists check
        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_response.json.return_value = {"result": {"exists": True}}

        mock_sync_client_instance = MagicMock()
        mock_sync_client_instance.get.return_value = mock_response
        mock_sync_client.return_value = mock_sync_client_instance

        from litellm.caching.qdrant_semantic_cache import QdrantSemanticCache

        # Initialize cache
        qdrant_cache = QdrantSemanticCache(
            collection_name="test_collection",
            qdrant_api_base="http://test.qdrant.local",
            qdrant_api_key="test_key",
            similarity_threshold=0.8,
        )

        # Mock a cache hit result from search API
        mock_search_response = MagicMock()
        mock_search_response.status_code = 200
        mock_search_response.json.return_value = {
            "result": [
                {
                    "payload": {
                        "text": "What is the capital of France?",  # Original prompt
                        "response": '{"id": "test-123", "choices": [{"message": {"content": "Paris is the capital of France."}}]}'
                    },
                    "score": 0.9
                }
            ]
        }
        qdrant_cache.sync_client.post = MagicMock(return_value=mock_search_response)

        # Mock the embedding function
        with patch(
            "litellm.embedding",
            return_value={"data": [{"embedding": [0.1, 0.2, 0.3]}]}
        ):
            # Test get_cache with a message
            result = qdrant_cache.get_cache(
                key="test_key",
                messages=[{"content": "What is the capital of France?"}]
            )

        # Verify result is properly parsed
        expected_result = {
            "id": "test-123",
            "choices": [{"message": {"content": "Paris is the capital of France."}}]
        }
        assert result == expected_result

        # Verify search was called
        qdrant_cache.sync_client.post.assert_called()


def test_qdrant_semantic_cache_get_cache_miss():
    """
    Test QDRANT semantic cache get method when there's a cache miss.
    Verifies that None is returned when no similar cached results are found.
    """
    with patch("litellm.llms.custom_httpx.http_handler._get_httpx_client") as mock_sync_client, \
            patch("litellm.llms.custom_httpx.http_handler.get_async_httpx_client") as mock_async_client:

        # Mock the collection exists check
        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_response.json.return_value = {"result": {"exists": True}}

        mock_sync_client_instance = MagicMock()
        mock_sync_client_instance.get.return_value = mock_response
        mock_sync_client.return_value = mock_sync_client_instance

        from litellm.caching.qdrant_semantic_cache import QdrantSemanticCache

        # Initialize cache
        qdrant_cache = QdrantSemanticCache(
            collection_name="test_collection",
            qdrant_api_base="http://test.qdrant.local",
            qdrant_api_key="test_key",
            similarity_threshold=0.8,
        )

        # Mock a cache miss (no results)
        mock_search_response = MagicMock()
        mock_search_response.status_code = 200
        mock_search_response.json.return_value = {"result": []}
        qdrant_cache.sync_client.post = MagicMock(return_value=mock_search_response)

        # Mock the embedding function
        with patch(
            "litellm.embedding",
            return_value={"data": [{"embedding": [0.1, 0.2, 0.3]}]}
        ):
            # Test get_cache with a message
            result = qdrant_cache.get_cache(
                key="test_key",
                messages=[{"content": "What is the capital of Spain?"}]
            )

        # Verify None is returned for cache miss
        assert result is None

        # Verify search was called
        qdrant_cache.sync_client.post.assert_called()


@pytest.mark.asyncio
async def test_qdrant_semantic_cache_async_get_cache_hit():
    """
    Test QDRANT semantic cache async get method when there's a cache hit.
    Verifies that cached results are properly retrieved and parsed asynchronously.
    """
    with patch("litellm.llms.custom_httpx.http_handler._get_httpx_client") as mock_sync_client, \
            patch("litellm.llms.custom_httpx.http_handler.get_async_httpx_client") as mock_async_client:

        # Mock the collection exists check
        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_response.json.return_value = {"result": {"exists": True}}

        mock_sync_client_instance = MagicMock()
        mock_sync_client_instance.get.return_value = mock_response
        mock_sync_client.return_value = mock_sync_client_instance

        # Mock async client
        mock_async_client_instance = AsyncMock()
        mock_async_client.return_value = mock_async_client_instance

        from litellm.caching.qdrant_semantic_cache import QdrantSemanticCache

        # Initialize cache
        qdrant_cache = QdrantSemanticCache(
            collection_name="test_collection",
            qdrant_api_base="http://test.qdrant.local",
            qdrant_api_key="test_key",
            similarity_threshold=0.8,
        )

        # Mock a cache hit result from async search API
        # Note: .json() should be sync even for async responses
        mock_search_response = MagicMock()
        mock_search_response.status_code = 200
        mock_search_response.json.return_value = {
            "result": [
                {
                    "payload": {
                        "text": "What is the capital of Spain?",  # Original prompt
                        "response": '{"id": "test-456", "choices": [{"message": {"content": "Madrid is the capital of Spain."}}]}'
                    },
                    "score": 0.85
                }
            ]
        }
        qdrant_cache.async_client.post = AsyncMock(return_value=mock_search_response)

        # Mock the async embedding function
        with patch(
            "litellm.aembedding",
            return_value={"data": [{"embedding": [0.4, 0.5, 0.6]}]}
        ):
            # Test async_get_cache with a message
            result = await qdrant_cache.async_get_cache(
                key="test_key",
                messages=[{"content": "What is the capital of Spain?"}],
                metadata={},
            )

        # Verify result is properly parsed
        expected_result = {
            "id": "test-456",
            "choices": [{"message": {"content": "Madrid is the capital of Spain."}}]
        }
        assert result == expected_result

        # Verify async search was called
        qdrant_cache.async_client.post.assert_called()


@pytest.mark.asyncio
async def test_qdrant_semantic_cache_async_get_cache_miss():
    """
    Test QDRANT semantic cache async get method when there's a cache miss.
    Verifies that None is returned when no similar cached results are found.
    """
    with patch("litellm.llms.custom_httpx.http_handler._get_httpx_client") as mock_sync_client, \
            patch("litellm.llms.custom_httpx.http_handler.get_async_httpx_client") as mock_async_client:

        # Mock the collection exists check
        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_response.json.return_value = {"result": {"exists": True}}

        mock_sync_client_instance = MagicMock()
        mock_sync_client_instance.get.return_value = mock_response
        mock_sync_client.return_value = mock_sync_client_instance

        # Mock async client
        mock_async_client_instance = AsyncMock()
        mock_async_client.return_value = mock_async_client_instance

        from litellm.caching.qdrant_semantic_cache import QdrantSemanticCache

        # Initialize cache
        qdrant_cache = QdrantSemanticCache(
            collection_name="test_collection",
            qdrant_api_base="http://test.qdrant.local",
            qdrant_api_key="test_key",
            similarity_threshold=0.8,
        )

        # Mock a cache miss (no results)
        mock_search_response = MagicMock()  # Note: .json() should be sync
        mock_search_response.status_code = 200
        mock_search_response.json.return_value = {"result": []}
        qdrant_cache.async_client.post = AsyncMock(return_value=mock_search_response)

        # Mock the async embedding function
        with patch(
            "litellm.aembedding",
            return_value={"data": [{"embedding": [0.7, 0.8, 0.9]}]}
        ):
            # Test async_get_cache with a message
            result = await qdrant_cache.async_get_cache(
                key="test_key",
                messages=[{"content": "What is the capital of Italy?"}],
                metadata={},
            )

        # Verify None is returned for cache miss
        assert result is None

        # Verify async search was called
        qdrant_cache.async_client.post.assert_called()


def test_qdrant_semantic_cache_set_cache():
    """
    Test QDRANT semantic cache set method.
    Verifies that responses are properly stored in the cache.
    """
    with patch("litellm.llms.custom_httpx.http_handler._get_httpx_client") as mock_sync_client, \
            patch("litellm.llms.custom_httpx.http_handler.get_async_httpx_client") as mock_async_client:

        # Mock the collection exists check
        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_response.json.return_value = {"result": {"exists": True}}

        mock_sync_client_instance = MagicMock()
        mock_sync_client_instance.get.return_value = mock_response
        mock_sync_client.return_value = mock_sync_client_instance

        from litellm.caching.qdrant_semantic_cache import QdrantSemanticCache

        # Initialize cache
        qdrant_cache = QdrantSemanticCache(
            collection_name="test_collection",
            qdrant_api_base="http://test.qdrant.local",
            qdrant_api_key="test_key",
            similarity_threshold=0.8,
        )

        # Mock the upsert method
        mock_upsert_response = MagicMock()
        mock_upsert_response.status_code = 200
        qdrant_cache.sync_client.put = MagicMock(return_value=mock_upsert_response)

        # Mock response to cache
        response_to_cache = {
            "id": "test-789",
            "choices": [{"message": {"content": "Rome is the capital of Italy."}}]
        }

        # Mock the embedding function
        with patch(
            "litellm.embedding",
            return_value={"data": [{"embedding": [0.1, 0.1, 0.1]}]}
        ):
            # Test set_cache
            qdrant_cache.set_cache(
                key="test_key",
                value=response_to_cache,
                messages=[{"content": "What is the capital of Italy?"}]
            )

        # Verify upsert was called
        qdrant_cache.sync_client.put.assert_called()


@pytest.mark.asyncio
async def test_qdrant_semantic_cache_async_set_cache():
    """
    Test QDRANT semantic cache async set method.
    Verifies that responses are properly stored in the cache asynchronously.
    """
    with patch("litellm.llms.custom_httpx.http_handler._get_httpx_client") as mock_sync_client, \
            patch("litellm.llms.custom_httpx.http_handler.get_async_httpx_client") as mock_async_client:

        # Mock the collection exists check
        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_response.json.return_value = {"result": {"exists": True}}

        mock_sync_client_instance = MagicMock()
        mock_sync_client_instance.get.return_value = mock_response
        mock_sync_client.return_value = mock_sync_client_instance

        # Mock async client
        mock_async_client_instance = AsyncMock()
        mock_async_client.return_value = mock_async_client_instance

        from litellm.caching.qdrant_semantic_cache import QdrantSemanticCache

        # Initialize cache
        qdrant_cache = QdrantSemanticCache(
            collection_name="test_collection",
            qdrant_api_base="http://test.qdrant.local",
            qdrant_api_key="test_key",
            similarity_threshold=0.8,
        )

        # Mock the async upsert method
        mock_upsert_response = MagicMock()  # Note: .json() should be sync
        mock_upsert_response.status_code = 200
        qdrant_cache.async_client.put = AsyncMock(return_value=mock_upsert_response)

        # Mock response to cache
        response_to_cache = {
            "id": "test-999",
            "choices": [{"message": {"content": "Berlin is the capital of Germany."}}]
        }

        # Mock the async embedding function
        with patch(
            "litellm.aembedding",
            return_value={"data": [{"embedding": [0.2, 0.2, 0.2]}]}
        ):
            # Test async_set_cache
            await qdrant_cache.async_set_cache(
                key="test_key",
                value=response_to_cache,
                messages=[{"content": "What is the capital of Germany?"}],
                metadata={}
            )

        # Verify async upsert was called
        qdrant_cache.async_client.put.assert_called()
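Across the Qdrant hit/miss tests above, the deciding quantity is the search score measured against similarity_threshold=0.8: the hit payloads carry scores of 0.9 and 0.85, and the miss cases return empty results. A small sketch of the implied comparison (the 0.75 value is illustrative, not taken from the tests):

similarity_threshold = 0.8

# Scores from the hit tests above, plus one illustrative sub-threshold value
for score, expected_hit in [(0.9, True), (0.85, True), (0.75, False)]:
    assert (score >= similarity_threshold) == expected_hit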
@@ -0,0 +1,122 @@
import os
import sys
from unittest.mock import MagicMock, patch

import pytest
from fastapi.testclient import TestClient

sys.path.insert(
    0, os.path.abspath("../../..")
)  # Adds the parent directory to the system path
from unittest.mock import AsyncMock

from litellm.caching.redis_cache import RedisCache


@pytest.fixture
def redis_no_ping():
    """Patch RedisCache initialization to prevent async ping tasks from being created"""
    with patch("asyncio.get_running_loop") as mock_get_loop:
        # Either raise an exception or return a mock that will handle the task creation
        mock_get_loop.side_effect = RuntimeError("No running event loop")
        yield


@pytest.mark.parametrize("namespace", [None, "test"])
@pytest.mark.asyncio
async def test_redis_cache_async_increment(namespace, monkeypatch, redis_no_ping):
    monkeypatch.setenv("REDIS_HOST", "https://my-test-host")
    redis_cache = RedisCache(namespace=namespace)
    # Create an AsyncMock for the Redis client
    mock_redis_instance = AsyncMock()

    # Make sure the mock can be used as an async context manager
    mock_redis_instance.__aenter__.return_value = mock_redis_instance
    mock_redis_instance.__aexit__.return_value = None

    assert redis_cache is not None

    expected_key = "test:test" if namespace else "test"

    with patch.object(
        redis_cache, "init_async_client", return_value=mock_redis_instance
    ):
        # Call async_increment
        await redis_cache.async_increment(key=expected_key, value=1)

        # Verify that incrbyfloat was called on the mock Redis instance
        mock_redis_instance.incrbyfloat.assert_called_once_with(
            name=expected_key, amount=1
        )


@pytest.mark.asyncio
async def test_redis_client_init_with_socket_timeout(monkeypatch, redis_no_ping):
    monkeypatch.setenv("REDIS_HOST", "my-fake-host")
    redis_cache = RedisCache(socket_timeout=1.0)
    assert redis_cache.redis_kwargs["socket_timeout"] == 1.0
    client = redis_cache.init_async_client()
    assert client is not None
    assert client.connection_pool.connection_kwargs["socket_timeout"] == 1.0


@pytest.mark.asyncio
async def test_redis_cache_async_batch_get_cache(monkeypatch, redis_no_ping):
    monkeypatch.setenv("REDIS_HOST", "https://my-test-host")
    redis_cache = RedisCache()

    # Create an AsyncMock for the Redis client
    mock_redis_instance = AsyncMock()

    # Make sure the mock can be used as an async context manager
    mock_redis_instance.__aenter__.return_value = mock_redis_instance
    mock_redis_instance.__aexit__.return_value = None

    # Setup the return value for mget
    mock_redis_instance.mget.return_value = [
        b'{"key1": "value1"}',
        None,
        b'{"key3": "value3"}',
    ]

    test_keys = ["key1", "key2", "key3"]

    with patch.object(
        redis_cache, "init_async_client", return_value=mock_redis_instance
    ):
        # Call async_batch_get_cache
        result = await redis_cache.async_batch_get_cache(key_list=test_keys)

        # Verify mget was called with the correct keys
        mock_redis_instance.mget.assert_called_once()

        # Check that results were properly decoded
        assert result["key1"] == {"key1": "value1"}
        assert result["key2"] is None
        assert result["key3"] == {"key3": "value3"}


@pytest.mark.asyncio
async def test_handle_lpop_count_for_older_redis_versions(monkeypatch):
    """Test the helper method that handles LPOP with count for Redis versions < 7.0"""
    monkeypatch.setenv("REDIS_HOST", "https://my-test-host")
    # Create RedisCache instance
    redis_cache = RedisCache()

    # Create a mock pipeline
    mock_pipeline = AsyncMock()
    # Set up execute to return different values each time
    mock_pipeline.execute.side_effect = [
        [b"value1"],  # First execute returns first value
        [b"value2"],  # Second execute returns second value
    ]

    # Test the helper method
    result = await redis_cache.handle_lpop_count_for_older_redis_versions(
        pipe=mock_pipeline, key="test_key", count=2
    )

    # Verify results
    assert result == [b"value1", b"value2"]
    assert mock_pipeline.lpop.call_count == 2
    assert mock_pipeline.execute.call_count == 2
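The redis_no_ping fixture works because asyncio.get_running_loop() raises RuntimeError when no loop is running, and the side_effect reproduces exactly that outside-a-loop behavior, so no ping task can be scheduled during RedisCache initialization. A self-contained sketch of the mechanism:

import asyncio
from unittest.mock import patch

with patch("asyncio.get_running_loop", side_effect=RuntimeError("No running event loop")):
    try:
        asyncio.get_running_loop()
    except RuntimeError as exc:
        # Same failure mode as calling it outside an event loop, so code that
        # guards task creation with get_running_loop() takes its fallback path.
        print(exc)  # No running event loop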
@@ -0,0 +1,66 @@
import json
import os
import sys
from unittest.mock import MagicMock, patch

import pytest
from fastapi.testclient import TestClient

sys.path.insert(
    0, os.path.abspath("../../..")
)  # Adds the parent directory to the system path

from litellm.caching.redis_cluster_cache import RedisClusterCache


@patch("litellm._redis.init_redis_cluster")
def test_redis_cluster_batch_get(mock_init_redis_cluster):
    """
    Test that RedisClusterCache uses mget_nonatomic instead of mget for batch operations
    """
    # Create a mock Redis client
    mock_redis = MagicMock()
    mock_redis.mget_nonatomic.return_value = [None, None]  # Simulate no cache hits
    mock_init_redis_cluster.return_value = mock_redis

    # Create RedisClusterCache instance with mock client
    cache = RedisClusterCache(
        startup_nodes=[{"host": "localhost", "port": 6379}],
        password="hello",
    )

    # Test batch_get_cache
    keys = ["key1", "key2"]
    cache.batch_get_cache(keys)

    # Verify mget_nonatomic was called instead of mget
    mock_redis.mget_nonatomic.assert_called_once()
    assert not mock_redis.mget.called


@pytest.mark.asyncio
@patch("litellm._redis.init_redis_cluster")
async def test_redis_cluster_async_batch_get(mock_init_redis_cluster):
    """
    Test that RedisClusterCache uses mget_nonatomic instead of mget for async batch operations
    """
    # Create a mock Redis client
    mock_redis = MagicMock()
    mock_redis.mget_nonatomic.return_value = [None, None]  # Simulate no cache hits

    # Create RedisClusterCache instance with mock client
    cache = RedisClusterCache(
        startup_nodes=[{"host": "localhost", "port": 6379}],
        password="hello",
    )

    # Mock the init_async_client to return our mock redis client
    cache.init_async_client = MagicMock(return_value=mock_redis)

    # Test async_batch_get_cache
    keys = ["key1", "key2"]
    await cache.async_batch_get_cache(keys)

    # Verify mget_nonatomic was called instead of mget
    mock_redis.mget_nonatomic.assert_called_once()
    assert not mock_redis.mget.called
@@ -0,0 +1,146 @@
import os
import sys
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

sys.path.insert(
    0, os.path.abspath("../../..")
)  # Adds the parent directory to the system path


# Tests for RedisSemanticCache
def test_redis_semantic_cache_initialization(monkeypatch):
    # Mock the redisvl import
    semantic_cache_mock = MagicMock()
    with patch.dict(
        "sys.modules",
        {
            "redisvl.extensions.llmcache": MagicMock(SemanticCache=semantic_cache_mock),
            "redisvl.utils.vectorize": MagicMock(CustomTextVectorizer=MagicMock()),
        },
    ):
        from litellm.caching.redis_semantic_cache import RedisSemanticCache

        # Set environment variables
        monkeypatch.setenv("REDIS_HOST", "localhost")
        monkeypatch.setenv("REDIS_PORT", "6379")
        monkeypatch.setenv("REDIS_PASSWORD", "test_password")

        # Initialize the cache with a similarity threshold
        redis_semantic_cache = RedisSemanticCache(similarity_threshold=0.8)

        # Verify the semantic cache was initialized with correct parameters
        assert redis_semantic_cache.similarity_threshold == 0.8

        # Use pytest.approx for floating point comparison to handle precision issues
        assert redis_semantic_cache.distance_threshold == pytest.approx(0.2, abs=1e-10)
        assert redis_semantic_cache.embedding_model == "text-embedding-ada-002"

        # Test initialization with missing similarity_threshold
        with pytest.raises(ValueError, match="similarity_threshold must be provided"):
            RedisSemanticCache()


def test_redis_semantic_cache_get_cache(monkeypatch):
    # Mock the redisvl import and embedding function
    semantic_cache_mock = MagicMock()
    custom_vectorizer_mock = MagicMock()

    with patch.dict(
        "sys.modules",
        {
            "redisvl.extensions.llmcache": MagicMock(SemanticCache=semantic_cache_mock),
            "redisvl.utils.vectorize": MagicMock(
                CustomTextVectorizer=custom_vectorizer_mock
            ),
        },
    ):
        from litellm.caching.redis_semantic_cache import RedisSemanticCache

        # Set environment variables
        monkeypatch.setenv("REDIS_HOST", "localhost")
        monkeypatch.setenv("REDIS_PORT", "6379")
        monkeypatch.setenv("REDIS_PASSWORD", "test_password")

        # Initialize cache
        redis_semantic_cache = RedisSemanticCache(similarity_threshold=0.8)

        # Mock the llmcache.check method to return a result
        mock_result = [
            {
                "prompt": "What is the capital of France?",
                "response": '{"content": "Paris is the capital of France."}',
                "vector_distance": 0.1,  # Distance of 0.1 means similarity of 0.9
            }
        ]
        redis_semantic_cache.llmcache.check = MagicMock(return_value=mock_result)

        # Mock the embedding function
        with patch(
            "litellm.embedding", return_value={"data": [{"embedding": [0.1, 0.2, 0.3]}]}
        ):
            # Test get_cache with a message
            result = redis_semantic_cache.get_cache(
                key="test_key", messages=[{"content": "What is the capital of France?"}]
            )

        # Verify result is properly parsed
        assert result == {"content": "Paris is the capital of France."}

        # Verify llmcache.check was called
        redis_semantic_cache.llmcache.check.assert_called_once()


@pytest.mark.asyncio
async def test_redis_semantic_cache_async_get_cache(monkeypatch):
    # Mock the redisvl import
    semantic_cache_mock = MagicMock()
    custom_vectorizer_mock = MagicMock()

    with patch.dict(
        "sys.modules",
        {
            "redisvl.extensions.llmcache": MagicMock(SemanticCache=semantic_cache_mock),
            "redisvl.utils.vectorize": MagicMock(
                CustomTextVectorizer=custom_vectorizer_mock
            ),
        },
    ):
        from litellm.caching.redis_semantic_cache import RedisSemanticCache

        # Set environment variables
        monkeypatch.setenv("REDIS_HOST", "localhost")
        monkeypatch.setenv("REDIS_PORT", "6379")
        monkeypatch.setenv("REDIS_PASSWORD", "test_password")

        # Initialize cache
        redis_semantic_cache = RedisSemanticCache(similarity_threshold=0.8)

        # Mock the async methods
        mock_result = [
            {
                "prompt": "What is the capital of France?",
                "response": '{"content": "Paris is the capital of France."}',
                "vector_distance": 0.1,  # Distance of 0.1 means similarity of 0.9
            }
        ]

        redis_semantic_cache.llmcache.acheck = AsyncMock(return_value=mock_result)
        redis_semantic_cache._get_async_embedding = AsyncMock(
            return_value=[0.1, 0.2, 0.3]
        )

        # Test async_get_cache with a message
        result = await redis_semantic_cache.async_get_cache(
            key="test_key",
            messages=[{"content": "What is the capital of France?"}],
            metadata={},
        )

        # Verify result is properly parsed
        assert result == {"content": "Paris is the capital of France."}

        # Verify methods were called
        redis_semantic_cache._get_async_embedding.assert_called_once()
        redis_semantic_cache.llmcache.acheck.assert_called_once()
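The initialization test above pins down the conversion the semantic lookups rely on: distance_threshold = 1 - similarity_threshold, so a vector_distance of 0.1 (similarity 0.9) clears the 0.8 threshold used throughout. In sketch form:

similarity_threshold = 0.8
distance_threshold = 1 - similarity_threshold  # 0.2, as the init test asserts

vector_distance = 0.1  # from the mocked check/acheck results above
assert abs(distance_threshold - 0.2) < 1e-10
assert vector_distance <= distance_threshold  # similarity 0.9 >= 0.8, a cache hit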