Added LiteLLM to the stack

2025-08-18 09:40:50 +00:00
parent 0648c1968c
commit d220b04e32
2682 changed files with 533609 additions and 1 deletion


@@ -0,0 +1,330 @@
import asyncio
import json
import sys
import os
import tempfile
from typing import Any, AsyncIterator, Dict, List, Optional, Union
import pytest
sys.path.insert(
    0, os.path.abspath("../../..")
)  # three directories up, so the local litellm package is importable
import litellm
from litellm.google_genai import (
generate_content,
agenerate_content,
generate_content_stream,
agenerate_content_stream,
)
from google.genai.types import ContentDict, PartDict, GenerateContentResponse
from litellm.integrations.custom_logger import CustomLogger
from litellm.types.utils import StandardLoggingPayload


def load_vertex_ai_credentials(model: str):
    """Load Vertex AI credentials for tests; returns the temp key file path or None"""
    if "vertex_ai" not in model:
        return None
    print("loading vertex ai credentials")
    # Define the path to the vertex_key.json file
    filepath = os.path.dirname(os.path.abspath(__file__))
    vertex_key_path = os.path.join(filepath, "vertex_key.json")
# Read the existing content of the file or create an empty dictionary
try:
with open(vertex_key_path, "r") as file:
# Read the file content
print("Read vertexai file path")
content = file.read()
            # If the file is empty, fall back to an empty dictionary
            if not content or not content.strip():
                service_account_key_data = {}
            else:
                # Parse the existing JSON content
                file.seek(0)
                service_account_key_data = json.load(file)
except FileNotFoundError:
# If the file doesn't exist, create an empty dictionary
service_account_key_data = {}
# Update the service_account_key_data with environment variables
private_key_id = os.environ.get("VERTEX_AI_PRIVATE_KEY_ID", "")
private_key = os.environ.get("VERTEX_AI_PRIVATE_KEY", "")
private_key = private_key.replace("\\n", "\n")
service_account_key_data["private_key_id"] = private_key_id
service_account_key_data["private_key"] = private_key
    # Write the updated key material to a temporary file; the caller is
    # responsible for deleting it when done (see cleanup_temp_files below)
    with tempfile.NamedTemporaryFile(mode="w+", delete=False) as temp_file:
        json.dump(service_account_key_data, temp_file, indent=2)

    # Export the temporary file as GOOGLE_APPLICATION_CREDENTIALS
    os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = os.path.abspath(temp_file.name)
    return os.path.abspath(temp_file.name)


class TestCustomLogger(CustomLogger):
    def __init__(self):
        self.standard_logging_object: Optional[StandardLoggingPayload] = None

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        print("in async_log_success_event")
        print("kwargs=", json.dumps(kwargs, indent=4, default=str))
        self.standard_logging_object = kwargs["standard_logging_object"]


class BaseGoogleGenAITest:
    """Base class for Google GenAI generate content tests to reduce code duplication"""

    @property
    def model_config(self) -> Dict[str, Any]:
        """Override in subclasses to provide model-specific configuration"""
        raise NotImplementedError("Subclasses must implement model_config")

    @property
    def _temp_files_to_cleanup(self):
        """Lazy initialization of temp files list"""
        if not hasattr(self, "_temp_files_list"):
            self._temp_files_list = []
        return self._temp_files_list

    def cleanup_temp_files(self):
"""Clean up any temporary files created during testing"""
for temp_file in self._temp_files_to_cleanup:
try:
os.unlink(temp_file)
except OSError:
pass # File might already be deleted
self._temp_files_to_cleanup.clear()
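
    # Hypothetical convenience fixture (an assumption, not part of the original
    # commit): the tests below append temp credential files to
    # _temp_files_to_cleanup but never call cleanup_temp_files, so an autouse
    # fixture like this would guarantee cleanup even when an assertion fails.
    @pytest.fixture(autouse=True)
    def _cleanup_temp_files_fixture(self):
        yield
        self.cleanup_temp_files()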

    def _validate_non_streaming_response(self, response: Any):
        """Validate non-streaming response structure"""
        # A streaming iterator here means the wrong code path was taken
        if isinstance(response, AsyncIterator):
            pytest.fail("Expected non-streaming response but got AsyncIterator")
        assert isinstance(
            response, GenerateContentResponse
        ), f"Expected GenerateContentResponse, got {type(response)}"
        print(f"Response: {response.model_dump_json(indent=4)}")
        # Keep further checks flexible - the exact Google GenAI response structure may vary
        assert response is not None, "Response should not be None"

    def _validate_streaming_response(self, chunks: List[Any]):
        """Validate streaming response chunks"""
        assert isinstance(chunks, list), f"Expected list of chunks, got {type(chunks)}"
        assert len(chunks) > 0, "Should have received at least one chunk"
        print(f"Total chunks received: {len(chunks)}")

    def _validate_standard_logging_payload(
        self, slp: StandardLoggingPayload, response: Any
    ):
        """
        Validate that a StandardLoggingPayload object matches the expected response
        for Google GenAI.

        Args:
            slp (StandardLoggingPayload): The standard logging payload object to validate
            response: The Google GenAI response to compare against
        """
# Validate payload exists
assert slp is not None, "Standard logging payload should not be None"
# Validate basic structure
assert "prompt_tokens" in slp, "Standard logging payload should have prompt_tokens"
assert "completion_tokens" in slp, "Standard logging payload should have completion_tokens"
assert "total_tokens" in slp, "Standard logging payload should have total_tokens"
assert "response_cost" in slp, "Standard logging payload should have response_cost"
# Validate token counts are reasonable (non-negative numbers)
assert slp["prompt_tokens"] >= 0, "Prompt tokens should be non-negative"
assert slp["completion_tokens"] >= 0, "Completion tokens should be non-negative"
assert slp["total_tokens"] >= 0, "Total tokens should be non-negative"
# Validate spend
assert slp["response_cost"] >= 0, "Response cost should be non-negative"
print(f"Standard logging payload validation passed: prompt_tokens={slp['prompt_tokens']}, completion_tokens={slp['completion_tokens']}, total_tokens={slp['total_tokens']}, cost={slp['response_cost']}")

    @pytest.mark.parametrize("is_async", [False, True])
@pytest.mark.asyncio
async def test_non_streaming_base(self, is_async: bool):
"""Base test for non-streaming requests (parametrized for sync/async)"""
request_params = self.model_config
contents = ContentDict(
parts=[
PartDict(
text="Hello, can you tell me a short joke?"
)
],
role="user",
)
temp_file_path = load_vertex_ai_credentials(model=request_params["model"])
if temp_file_path:
self._temp_files_to_cleanup.append(temp_file_path)
litellm._turn_on_debug()
print(f"Testing {'async' if is_async else 'sync'} non-streaming with model config: {request_params}")
print(f"Contents: {contents}")
if is_async:
print("\n--- Testing async agenerate_content ---")
response = await agenerate_content(
contents=contents,
**request_params
)
else:
print("\n--- Testing sync generate_content ---")
response = generate_content(
contents=contents,
**request_params
)
print(f"{'Async' if is_async else 'Sync'} response: {json.dumps(response, indent=2, default=str)}")
self._validate_non_streaming_response(response)
return response

    @pytest.mark.parametrize("is_async", [False, True])
@pytest.mark.asyncio
async def test_streaming_base(self, is_async: bool):
"""Base test for streaming requests (parametrized for sync/async)"""
request_params = self.model_config
temp_file_path = load_vertex_ai_credentials(model=request_params["model"])
if temp_file_path:
self._temp_files_to_cleanup.append(temp_file_path)
contents = ContentDict(
parts=[
PartDict(
text="Hello, can you tell me a short joke?"
)
],
role="user",
)
print(f"Testing {'async' if is_async else 'sync'} streaming with model config: {request_params}")
print(f"Contents: {contents}")
chunks = []
if is_async:
print("\n--- Testing async agenerate_content_stream ---")
response = await agenerate_content_stream(
contents=contents,
**request_params
)
async for chunk in response:
print(f"Async chunk: {chunk}")
chunks.append(chunk)
else:
print("\n--- Testing sync generate_content_stream ---")
response = generate_content_stream(
contents=contents,
**request_params
)
for chunk in response:
print(f"Sync chunk: {chunk}")
chunks.append(chunk)
self._validate_streaming_response(chunks)
return chunks

    @pytest.mark.asyncio
async def test_async_non_streaming_with_logging(self):
"""Test async non-streaming Google GenAI generate content with logging"""
litellm._turn_on_debug()
litellm.logging_callback_manager._reset_all_callbacks()
litellm.set_verbose = True
test_custom_logger = TestCustomLogger()
litellm.callbacks = [test_custom_logger]
request_params = self.model_config
temp_file_path = load_vertex_ai_credentials(model=request_params["model"])
if temp_file_path:
self._temp_files_to_cleanup.append(temp_file_path)
contents = ContentDict(
parts=[
PartDict(
text="Hello, can you tell me a short joke?"
)
],
role="user",
)
print("\n--- Testing async agenerate_content with logging ---")
response = await agenerate_content(
contents=contents,
**request_params
)
print("Google GenAI response=", json.dumps(response, indent=4, default=str))
print("sleeping for 5 seconds...")
await asyncio.sleep(5)
print(
"standard logging payload=",
json.dumps(test_custom_logger.standard_logging_object, indent=4, default=str),
)
assert response is not None
assert test_custom_logger.standard_logging_object is not None
self._validate_standard_logging_payload(
test_custom_logger.standard_logging_object, response
)

    @pytest.mark.asyncio
async def test_async_streaming_with_logging(self):
"""Test async streaming Google GenAI generate content with logging"""
litellm._turn_on_debug()
litellm.set_verbose = True
litellm.logging_callback_manager._reset_all_callbacks()
test_custom_logger = TestCustomLogger()
litellm.callbacks = [test_custom_logger]
request_params = self.model_config
temp_file_path = load_vertex_ai_credentials(model=request_params["model"])
if temp_file_path:
self._temp_files_to_cleanup.append(temp_file_path)
contents = ContentDict(
parts=[
PartDict(
text="Hello, can you tell me a short joke?"
)
],
role="user",
)
print("\n--- Testing async agenerate_content_stream with logging ---")
response = await agenerate_content_stream(
contents=contents,
**request_params
)
chunks = []
async for chunk in response:
print(f"Google GenAI chunk: {chunk}")
chunks.append(chunk)
print("sleeping for 5 seconds...")
await asyncio.sleep(5)
print(
"standard logging payload=",
json.dumps(test_custom_logger.standard_logging_object, indent=4, default=str),
)
        assert len(chunks) > 0, "Should have received at least one chunk"
assert test_custom_logger.standard_logging_object is not None
self._validate_standard_logging_payload(
test_custom_logger.standard_logging_object, chunks
)
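

# Hypothetical example subclass (an assumption, not part of the original commit):
# BaseGoogleGenAITest only requires model_config to be overridden, so a concrete
# test class can be this small. The model name below is illustrative.
class TestGoogleGenAIGemini(BaseGoogleGenAITest):
    """Runs the shared generate-content tests against a Gemini model."""

    @property
    def model_config(self) -> Dict[str, Any]:
        return {
            "model": "gemini/gemini-1.5-flash",  # assumed model id; adjust as needed
        }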