import asyncio
import json
import sys
import os
import tempfile
from typing import Any, AsyncIterator, Dict, List, Optional

import pytest

sys.path.insert(
    0, os.path.abspath("../../..")
)  # Adds the repo root (three directories up) to the system path

import litellm
from litellm.google_genai import (
    generate_content,
    agenerate_content,
    generate_content_stream,
    agenerate_content_stream,
)
from google.genai.types import ContentDict, PartDict, GenerateContentResponse
from litellm.integrations.custom_logger import CustomLogger
from litellm.types.utils import StandardLoggingPayload

def load_vertex_ai_credentials(model: str):
    """Load Vertex AI credentials for tests.

    Returns the path to a temporary service-account key file, or None for
    non-Vertex models.
    """
    if "vertex_ai" not in model:
        return None
    print("loading vertex ai credentials")
    # Define the path to the vertex_key.json file
    filepath = os.path.dirname(os.path.abspath(__file__))
    vertex_key_path = filepath + "/vertex_key.json"

    # Read the existing content of the file or create an empty dictionary
    try:
        with open(vertex_key_path, "r") as file:
            # Read the file content
            print("Read vertexai file path")
            content = file.read()

            # If the file is empty, fall back to an empty dictionary
            if not content.strip():
                service_account_key_data = {}
            else:
                # Attempt to load the existing JSON content; fall back to an
                # empty dictionary if the file is not valid JSON
                file.seek(0)
                try:
                    service_account_key_data = json.load(file)
                except json.JSONDecodeError:
                    service_account_key_data = {}
    except FileNotFoundError:
        # If the file doesn't exist, create an empty dictionary
        service_account_key_data = {}

    # Update the service_account_key_data with environment variables
    private_key_id = os.environ.get("VERTEX_AI_PRIVATE_KEY_ID", "")
    private_key = os.environ.get("VERTEX_AI_PRIVATE_KEY", "")
    private_key = private_key.replace("\\n", "\n")
    service_account_key_data["private_key_id"] = private_key_id
    service_account_key_data["private_key"] = private_key

    # Create a temporary file
    with tempfile.NamedTemporaryFile(mode="w+", delete=False) as temp_file:
        # Write the updated content to the temporary file
        json.dump(service_account_key_data, temp_file, indent=2)

    # Export the temporary file as GOOGLE_APPLICATION_CREDENTIALS
    os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = os.path.abspath(temp_file.name)

    return os.path.abspath(temp_file.name)

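# Hypothetical usage sketch (not executed anywhere in this suite): the helper
# reads VERTEX_AI_PRIVATE_KEY_ID / VERTEX_AI_PRIVATE_KEY from the environment,
# writes them into a temporary service-account key file, and points
# GOOGLE_APPLICATION_CREDENTIALS at it. The model string below is an
# assumption for illustration.
#
#   os.environ["VERTEX_AI_PRIVATE_KEY_ID"] = "<key-id>"
#   os.environ["VERTEX_AI_PRIVATE_KEY"] = "-----BEGIN PRIVATE KEY-----\\n..."
#   creds_path = load_vertex_ai_credentials(model="vertex_ai/gemini-2.0-flash")
#   assert creds_path is not None  # None is returned for non-Vertex models
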
class TestCustomLogger(CustomLogger):
    def __init__(self):
        self.standard_logging_object: Optional[StandardLoggingPayload] = None

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        print("in async_log_success_event")
        print("kwargs=", json.dumps(kwargs, indent=4, default=str))
        self.standard_logging_object = kwargs["standard_logging_object"]

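# The logger above is registered via `litellm.callbacks = [test_custom_logger]`
# in the logging tests below; after a successful call, litellm invokes
# async_log_success_event and passes the StandardLoggingPayload in kwargs.
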
class BaseGoogleGenAITest:
    """Base class for Google GenAI generate content tests to reduce code duplication"""

    @property
    def model_config(self) -> Dict[str, Any]:
        """Override in subclasses to provide model-specific configuration"""
        raise NotImplementedError("Subclasses must implement model_config")

    @property
    def _temp_files_to_cleanup(self):
        """Lazy initialization of the temp-files list"""
        if not hasattr(self, "_temp_files_list"):
            self._temp_files_list = []
        return self._temp_files_list

    def cleanup_temp_files(self):
        """Clean up any temporary files created during testing"""
        for temp_file in self._temp_files_to_cleanup:
            try:
                os.unlink(temp_file)
            except OSError:
                pass  # File might already be deleted
        self._temp_files_to_cleanup.clear()

    def _validate_non_streaming_response(self, response: Any):
        """Validate non-streaming response structure"""
        # Type check - a non-streaming call must not return an async iterator
        if isinstance(response, AsyncIterator):
            pytest.fail("Expected non-streaming response but got AsyncIterator")

        assert isinstance(
            response, GenerateContentResponse
        ), f"Expected GenerateContentResponse, got {type(response)}"
        print(f"Response: {response.model_dump_json(indent=4)}")

        # Basic validation - adjust based on actual Google GenAI response structure;
        # the exact structure may vary, so we stay flexible here
        assert response is not None, "Response should not be None"

    def _validate_streaming_response(self, chunks: List[Any]):
        """Validate streaming response chunks"""
        assert isinstance(chunks, list), f"Expected list of chunks, got {type(chunks)}"
        assert len(chunks) > 0, "Should have received at least one chunk"
        print(f"Total chunks received: {len(chunks)}")

    def _validate_standard_logging_payload(
        self, slp: StandardLoggingPayload, response: Any
    ):
        """
        Validate that a StandardLoggingPayload object matches the expected response for Google GenAI

        Args:
            slp (StandardLoggingPayload): The standard logging payload object to validate
            response: The Google GenAI response to compare against
        """
        # Validate payload exists
        assert slp is not None, "Standard logging payload should not be None"

        # Validate basic structure
        assert "prompt_tokens" in slp, "Standard logging payload should have prompt_tokens"
        assert "completion_tokens" in slp, "Standard logging payload should have completion_tokens"
        assert "total_tokens" in slp, "Standard logging payload should have total_tokens"
        assert "response_cost" in slp, "Standard logging payload should have response_cost"

        # Validate token counts are reasonable (non-negative numbers)
        assert slp["prompt_tokens"] >= 0, "Prompt tokens should be non-negative"
        assert slp["completion_tokens"] >= 0, "Completion tokens should be non-negative"
        assert slp["total_tokens"] >= 0, "Total tokens should be non-negative"

        # Validate spend
        assert slp["response_cost"] >= 0, "Response cost should be non-negative"

        print(
            f"Standard logging payload validation passed: "
            f"prompt_tokens={slp['prompt_tokens']}, "
            f"completion_tokens={slp['completion_tokens']}, "
            f"total_tokens={slp['total_tokens']}, "
            f"cost={slp['response_cost']}"
        )

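    # Illustrative shape of the payload the assertions above rely on (the
    # values are made up; real payloads carry many more fields):
    #   {
    #       "prompt_tokens": 12,
    #       "completion_tokens": 25,
    #       "total_tokens": 37,
    #       "response_cost": 0.000123,
    #       ...
    #   }
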
    @pytest.mark.parametrize("is_async", [False, True])
    @pytest.mark.asyncio
    async def test_non_streaming_base(self, is_async: bool):
        """Base test for non-streaming requests (parametrized for sync/async)"""
        request_params = self.model_config
        contents = ContentDict(
            parts=[
                PartDict(text="Hello, can you tell me a short joke?")
            ],
            role="user",
        )
        temp_file_path = load_vertex_ai_credentials(model=request_params["model"])
        if temp_file_path:
            self._temp_files_to_cleanup.append(temp_file_path)

        litellm._turn_on_debug()

        print(f"Testing {'async' if is_async else 'sync'} non-streaming with model config: {request_params}")
        print(f"Contents: {contents}")

        if is_async:
            print("\n--- Testing async agenerate_content ---")
            response = await agenerate_content(contents=contents, **request_params)
        else:
            print("\n--- Testing sync generate_content ---")
            response = generate_content(contents=contents, **request_params)

        print(f"{'Async' if is_async else 'Sync'} response: {json.dumps(response, indent=2, default=str)}")
        self._validate_non_streaming_response(response)

        return response

    @pytest.mark.parametrize("is_async", [False, True])
    @pytest.mark.asyncio
    async def test_streaming_base(self, is_async: bool):
        """Base test for streaming requests (parametrized for sync/async)"""
        request_params = self.model_config
        temp_file_path = load_vertex_ai_credentials(model=request_params["model"])
        if temp_file_path:
            self._temp_files_to_cleanup.append(temp_file_path)
        contents = ContentDict(
            parts=[
                PartDict(text="Hello, can you tell me a short joke?")
            ],
            role="user",
        )

        print(f"Testing {'async' if is_async else 'sync'} streaming with model config: {request_params}")
        print(f"Contents: {contents}")

        chunks = []

        if is_async:
            print("\n--- Testing async agenerate_content_stream ---")
            response = await agenerate_content_stream(contents=contents, **request_params)
            async for chunk in response:
                print(f"Async chunk: {chunk}")
                chunks.append(chunk)
        else:
            print("\n--- Testing sync generate_content_stream ---")
            response = generate_content_stream(contents=contents, **request_params)
            for chunk in response:
                print(f"Sync chunk: {chunk}")
                chunks.append(chunk)

        self._validate_streaming_response(chunks)

        return chunks

    @pytest.mark.asyncio
    async def test_async_non_streaming_with_logging(self):
        """Test async non-streaming Google GenAI generate content with logging"""
        litellm._turn_on_debug()
        litellm.logging_callback_manager._reset_all_callbacks()
        litellm.set_verbose = True
        test_custom_logger = TestCustomLogger()
        litellm.callbacks = [test_custom_logger]

        request_params = self.model_config
        temp_file_path = load_vertex_ai_credentials(model=request_params["model"])
        if temp_file_path:
            self._temp_files_to_cleanup.append(temp_file_path)
        contents = ContentDict(
            parts=[
                PartDict(text="Hello, can you tell me a short joke?")
            ],
            role="user",
        )

        print("\n--- Testing async agenerate_content with logging ---")
        response = await agenerate_content(contents=contents, **request_params)

        print("Google GenAI response=", json.dumps(response, indent=4, default=str))

        # Give the async success callback time to fire before inspecting it
        print("sleeping for 5 seconds...")
        await asyncio.sleep(5)
        print(
            "standard logging payload=",
            json.dumps(test_custom_logger.standard_logging_object, indent=4, default=str),
        )

        assert response is not None
        assert test_custom_logger.standard_logging_object is not None

        self._validate_standard_logging_payload(
            test_custom_logger.standard_logging_object, response
        )

    @pytest.mark.asyncio
    async def test_async_streaming_with_logging(self):
        """Test async streaming Google GenAI generate content with logging"""
        litellm._turn_on_debug()
        litellm.set_verbose = True
        litellm.logging_callback_manager._reset_all_callbacks()
        test_custom_logger = TestCustomLogger()
        litellm.callbacks = [test_custom_logger]

        request_params = self.model_config
        temp_file_path = load_vertex_ai_credentials(model=request_params["model"])
        if temp_file_path:
            self._temp_files_to_cleanup.append(temp_file_path)
        contents = ContentDict(
            parts=[
                PartDict(text="Hello, can you tell me a short joke?")
            ],
            role="user",
        )

        print("\n--- Testing async agenerate_content_stream with logging ---")
        response = await agenerate_content_stream(contents=contents, **request_params)

        chunks = []
        async for chunk in response:
            print(f"Google GenAI chunk: {chunk}")
            chunks.append(chunk)

        # Give the async success callback time to fire before inspecting it
        print("sleeping for 5 seconds...")
        await asyncio.sleep(5)
        print(
            "standard logging payload=",
            json.dumps(test_custom_logger.standard_logging_object, indent=4, default=str),
        )

        assert len(chunks) > 0, "Should have received at least one chunk"
        assert test_custom_logger.standard_logging_object is not None

        self._validate_standard_logging_payload(
            test_custom_logger.standard_logging_object, chunks
        )
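

# A minimal, hypothetical example of how BaseGoogleGenAITest is meant to be
# specialized. The class name and model string are assumptions for
# illustration (the leading underscore keeps pytest from collecting it);
# real suites would point model_config at whichever provider they exercise.
class _ExampleGeminiTests(BaseGoogleGenAITest):
    @property
    def model_config(self) -> Dict[str, Any]:
        # Keys besides "model" are forwarded verbatim to generate_content /
        # agenerate_content via **request_params in the base tests.
        return {"model": "gemini/gemini-2.0-flash"}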