Added LiteLLM to the stack
This commit is contained in:
@@ -0,0 +1,229 @@
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../../../../..")
|
||||
) # Adds the parent directory to the system path
|
||||
|
||||
from litellm.llms.gemini.realtime.transformation import GeminiRealtimeConfig
|
||||
from litellm.types.llms.openai import OpenAIRealtimeStreamSessionEvents
|
||||
|
||||
|
||||
def test_gemini_realtime_transformation_session_created():
|
||||
config = GeminiRealtimeConfig()
|
||||
assert config is not None
|
||||
|
||||
session_configuration_request = {
|
||||
"model": "gemini-1.5-flash",
|
||||
"generationConfig": {"responseModalities": ["TEXT"]},
|
||||
}
|
||||
session_configuration_request_str = json.dumps(session_configuration_request)
|
||||
session_created_message = {"setupComplete": {}}
|
||||
|
||||
session_created_message_str = json.dumps(session_created_message)
|
||||
logging_obj = MagicMock()
|
||||
logging_obj.litellm_trace_id = "123"
|
||||
|
||||
transformed_message = config.transform_realtime_response(
|
||||
session_created_message_str,
|
||||
"gemini-1.5-flash",
|
||||
logging_obj,
|
||||
realtime_response_transform_input={
|
||||
"session_configuration_request": session_configuration_request_str,
|
||||
"current_output_item_id": None,
|
||||
"current_response_id": None,
|
||||
"current_conversation_id": None,
|
||||
"current_delta_chunks": [],
|
||||
"current_item_chunks": [],
|
||||
"current_delta_type": None,
|
||||
},
|
||||
)
|
||||
|
||||
print(transformed_message)
|
||||
assert transformed_message["response"][0]["type"] == "session.created"
|
||||
|
||||
|
||||
def test_gemini_realtime_transformation_content_delta():
|
||||
config = GeminiRealtimeConfig()
|
||||
assert config is not None
|
||||
|
||||
session_configuration_request = {
|
||||
"model": "gemini-1.5-flash",
|
||||
"generationConfig": {"responseModalities": ["TEXT"]},
|
||||
}
|
||||
session_configuration_request_str = json.dumps(session_configuration_request)
|
||||
session_created_message = {
|
||||
"serverContent": {
|
||||
"modelTurn": {
|
||||
"parts": [
|
||||
{"text": "Hello, world!"},
|
||||
{"text": "How are you?"},
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
session_created_message_str = json.dumps(session_created_message)
|
||||
logging_obj = MagicMock()
|
||||
logging_obj.litellm_trace_id.return_value = "123"
|
||||
|
||||
returned_object = config.transform_realtime_response(
|
||||
session_created_message_str,
|
||||
"gemini-1.5-flash",
|
||||
logging_obj,
|
||||
realtime_response_transform_input={
|
||||
"session_configuration_request": session_configuration_request_str,
|
||||
"current_output_item_id": None,
|
||||
"current_response_id": None,
|
||||
"current_conversation_id": None,
|
||||
"current_delta_chunks": [],
|
||||
"current_item_chunks": [],
|
||||
"current_delta_type": None,
|
||||
},
|
||||
)
|
||||
transformed_message = returned_object["response"]
|
||||
assert isinstance(transformed_message, list)
|
||||
print(transformed_message)
|
||||
transformed_message_str = json.dumps(transformed_message)
|
||||
assert "Hello, world" in transformed_message_str
|
||||
assert "How are you?" in transformed_message_str
|
||||
print(transformed_message)
|
||||
|
||||
## assert all instances of 'event_id' are unique
|
||||
event_ids = [
|
||||
event["event_id"] for event in transformed_message if "event_id" in event
|
||||
]
|
||||
assert len(event_ids) == len(set(event_ids))
|
||||
## assert all instances of 'response_id' are the same
|
||||
response_ids = [
|
||||
event["response_id"] for event in transformed_message if "response_id" in event
|
||||
]
|
||||
assert len(set(response_ids)) == 1
|
||||
## assert all instances of 'output_item_id' are the same
|
||||
output_item_ids = [
|
||||
event["item_id"] for event in transformed_message if "item_id" in event
|
||||
]
|
||||
assert len(set(output_item_ids)) == 1
|
||||
|
||||
|
||||
def test_gemini_model_turn_event_mapping():
|
||||
from litellm.types.llms.openai import OpenAIRealtimeEventTypes
|
||||
|
||||
config = GeminiRealtimeConfig()
|
||||
assert config is not None
|
||||
|
||||
model_turn_event = {"parts": [{"text": "Hello, world!"}]}
|
||||
openai_event = config.map_model_turn_event(model_turn_event)
|
||||
assert openai_event == OpenAIRealtimeEventTypes.RESPONSE_TEXT_DELTA
|
||||
|
||||
model_turn_event = {
|
||||
"parts": [{"inlineData": {"mimeType": "audio/pcm", "data": "..."}}]
|
||||
}
|
||||
openai_event = config.map_model_turn_event(model_turn_event)
|
||||
assert openai_event == OpenAIRealtimeEventTypes.RESPONSE_AUDIO_DELTA
|
||||
|
||||
model_turn_event = {
|
||||
"parts": [
|
||||
{
|
||||
"text": "Hello, world!",
|
||||
"inlineData": {"mimeType": "audio/pcm", "data": "..."},
|
||||
}
|
||||
]
|
||||
}
|
||||
openai_event = config.map_model_turn_event(model_turn_event)
|
||||
assert openai_event == OpenAIRealtimeEventTypes.RESPONSE_TEXT_DELTA
|
||||
|
||||
|
||||
def test_gemini_realtime_transformation_audio_delta():
|
||||
from litellm.types.llms.openai import OpenAIRealtimeEventTypes
|
||||
|
||||
config = GeminiRealtimeConfig()
|
||||
assert config is not None
|
||||
|
||||
session_configuration_request = {
|
||||
"model": "gemini-1.5-flash",
|
||||
"generationConfig": {"responseModalities": ["AUDIO"]},
|
||||
}
|
||||
session_configuration_request_str = json.dumps(session_configuration_request)
|
||||
|
||||
audio_delta_event = {
|
||||
"serverContent": {
|
||||
"modelTurn": {
|
||||
"parts": [
|
||||
{"inlineData": {"mimeType": "audio/pcm", "data": "my-audio-data"}}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
result = config.transform_realtime_response(
|
||||
json.dumps(audio_delta_event),
|
||||
"gemini-1.5-flash",
|
||||
MagicMock(),
|
||||
realtime_response_transform_input={
|
||||
"session_configuration_request": session_configuration_request_str,
|
||||
"current_output_item_id": None,
|
||||
"current_response_id": None,
|
||||
"current_conversation_id": None,
|
||||
"current_delta_chunks": [],
|
||||
"current_item_chunks": [],
|
||||
"current_delta_type": None,
|
||||
},
|
||||
)
|
||||
|
||||
print(result)
|
||||
|
||||
responses = result["response"]
|
||||
|
||||
contains_audio_delta = False
|
||||
for response in responses:
|
||||
if response["type"] == OpenAIRealtimeEventTypes.RESPONSE_AUDIO_DELTA.value:
|
||||
contains_audio_delta = True
|
||||
break
|
||||
assert contains_audio_delta, "Expected audio delta event"
|
||||
|
||||
|
||||
def test_gemini_realtime_transformation_generation_complete():
|
||||
from litellm.types.llms.openai import OpenAIRealtimeEventTypes
|
||||
|
||||
config = GeminiRealtimeConfig()
|
||||
assert config is not None
|
||||
|
||||
session_configuration_request = {
|
||||
"model": "gemini-1.5-flash",
|
||||
"generationConfig": {"responseModalities": ["AUDIO"]},
|
||||
}
|
||||
session_configuration_request_str = json.dumps(session_configuration_request)
|
||||
|
||||
audio_delta_event = {"serverContent": {"generationComplete": True}}
|
||||
|
||||
result = config.transform_realtime_response(
|
||||
json.dumps(audio_delta_event),
|
||||
"gemini-1.5-flash",
|
||||
MagicMock(),
|
||||
realtime_response_transform_input={
|
||||
"session_configuration_request": session_configuration_request_str,
|
||||
"current_output_item_id": "my-output-item-id",
|
||||
"current_response_id": "my-response-id",
|
||||
"current_conversation_id": None,
|
||||
"current_delta_chunks": [],
|
||||
"current_item_chunks": [],
|
||||
"current_delta_type": "audio",
|
||||
},
|
||||
)
|
||||
|
||||
print(result)
|
||||
|
||||
responses = result["response"]
|
||||
|
||||
contains_audio_done_event = False
|
||||
for response in responses:
|
||||
if response["type"] == OpenAIRealtimeEventTypes.RESPONSE_AUDIO_DONE.value:
|
||||
contains_audio_delta = True
|
||||
break
|
||||
assert contains_audio_delta, "Expected audio delta event"
|
@@ -0,0 +1,101 @@
|
||||
import pytest
|
||||
import litellm
|
||||
import os
|
||||
from unittest.mock import patch, Mock
|
||||
from litellm import completion
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_gemini_api_key(monkeypatch):
|
||||
monkeypatch.setenv("GOOGLE_API_KEY", "fake-gemini-key-for-testing")
|
||||
|
||||
|
||||
def test_gemini_completion():
|
||||
response = completion(
|
||||
model="gemini/gemini-2.0-flash-exp-image-generation",
|
||||
messages=[{"role": "user", "content": "Test message"}],
|
||||
mock_response="Test Message",
|
||||
)
|
||||
assert response.choices[0].message.content is not None
|
||||
|
||||
|
||||
def test_gemini_completion_no_api_key():
|
||||
"""Test Gemini completion fails gracefully when no API key is provided."""
|
||||
with patch.dict(os.environ, {}, clear=True):
|
||||
# Remove all API keys
|
||||
for key in ["GOOGLE_API_KEY", "GEMINI_API_KEY"]:
|
||||
if key in os.environ:
|
||||
del os.environ[key]
|
||||
|
||||
# Test without mock_response to ensure actual API key validation
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
completion(
|
||||
model="gemini/gemini-1.5-flash",
|
||||
messages=[{"role": "user", "content": "Test message"}],
|
||||
)
|
||||
|
||||
# Check that the exception message contains API key related text
|
||||
error_message = str(exc_info.value).lower()
|
||||
assert any(
|
||||
keyword in error_message
|
||||
for keyword in [
|
||||
"api key",
|
||||
"authentication",
|
||||
"unauthorized",
|
||||
"invalid",
|
||||
"missing",
|
||||
"credential",
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def test_gemini_completion_no_api_key_with_mock():
|
||||
"""Alternative test that properly mocks the API key validation."""
|
||||
with patch.dict(os.environ, {}, clear=True):
|
||||
# Remove all API keys
|
||||
for key in ["GOOGLE_API_KEY", "GEMINI_API_KEY"]:
|
||||
if key in os.environ:
|
||||
del os.environ[key]
|
||||
|
||||
with patch("litellm.get_secret") as mock_get_secret:
|
||||
mock_get_secret.return_value = None
|
||||
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
completion(
|
||||
model="gemini/gemini-1.5-flash",
|
||||
messages=[{"role": "user", "content": "Test message"}],
|
||||
)
|
||||
|
||||
error_message = str(exc_info.value).lower()
|
||||
assert any(
|
||||
keyword in error_message
|
||||
for keyword in [
|
||||
"api key",
|
||||
"authentication",
|
||||
"unauthorized",
|
||||
"invalid",
|
||||
"missing",
|
||||
"credential",
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("api_key_env", ["GOOGLE_API_KEY", "GEMINI_API_KEY"])
|
||||
def test_gemini_completion_both_env_vars(monkeypatch, api_key_env):
|
||||
"""Test Gemini completion works with both environment variable names."""
|
||||
# Clear all API keys first
|
||||
monkeypatch.delenv("GOOGLE_API_KEY", raising=False)
|
||||
monkeypatch.delenv("GEMINI_API_KEY", raising=False)
|
||||
|
||||
# Set the specific API key being tested
|
||||
monkeypatch.setenv(api_key_env, f"fake-{api_key_env.lower()}-for-testing")
|
||||
|
||||
response = completion(
|
||||
model="gemini/gemini-1.5-flash",
|
||||
messages=[{"role": "user", "content": f"Test with {api_key_env}"}],
|
||||
mock_response=f"Mocked response using {api_key_env}",
|
||||
)
|
||||
assert (
|
||||
response["choices"][0]["message"]["content"]
|
||||
== f"Mocked response using {api_key_env}"
|
||||
)
|
@@ -0,0 +1,160 @@
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from litellm.llms.gemini.common_utils import GeminiModelInfo, GoogleAIStudioTokenCounter
|
||||
|
||||
|
||||
class TestGeminiModelInfo:
|
||||
"""Test suite for GeminiModelInfo class"""
|
||||
|
||||
def test_process_model_name_normal_cases(self):
|
||||
"""Test process_model_name with normal model names"""
|
||||
gemini_model_info = GeminiModelInfo()
|
||||
|
||||
# Test with normal model names
|
||||
models = [
|
||||
{"name": "models/gemini-1.5-flash"},
|
||||
{"name": "models/gemini-1.5-pro"},
|
||||
{"name": "models/gemini-2.0-flash-exp"},
|
||||
]
|
||||
|
||||
result = gemini_model_info.process_model_name(models)
|
||||
|
||||
expected = [
|
||||
"gemini/gemini-1.5-flash",
|
||||
"gemini/gemini-1.5-pro",
|
||||
"gemini/gemini-2.0-flash-exp",
|
||||
]
|
||||
|
||||
assert result == expected
|
||||
|
||||
def test_process_model_name_edge_cases(self):
|
||||
"""Test process_model_name with edge cases that could be affected by strip() vs replace()"""
|
||||
gemini_model_info = GeminiModelInfo()
|
||||
|
||||
# Test edge cases where model names end with characters from "models/"
|
||||
# These would be incorrectly processed if using strip("models/") instead of replace("models/", "")
|
||||
models = [
|
||||
{
|
||||
"name": "models/gemini-1.5-pro"
|
||||
}, # ends with 'o' - would become "gemini-1.5-pr" with strip()
|
||||
{
|
||||
"name": "models/test-model"
|
||||
}, # ends with 'l' - would become "gemini/test-mode" with strip()
|
||||
{
|
||||
"name": "models/custom-models"
|
||||
}, # ends with 's' - would become "gemini/custom-model" with strip()
|
||||
{
|
||||
"name": "models/demo"
|
||||
}, # ends with 'o' - would become "gemini/dem" with strip()
|
||||
]
|
||||
|
||||
result = gemini_model_info.process_model_name(models)
|
||||
|
||||
expected = [
|
||||
"gemini/gemini-1.5-pro", # 'o' should be preserved
|
||||
"gemini/test-model", # 'l' should be preserved
|
||||
"gemini/custom-models", # 's' should be preserved
|
||||
"gemini/demo", # 'o' should be preserved
|
||||
]
|
||||
|
||||
assert result == expected
|
||||
|
||||
def test_process_model_name_empty_list(self):
|
||||
"""Test process_model_name with empty list"""
|
||||
gemini_model_info = GeminiModelInfo()
|
||||
|
||||
result = gemini_model_info.process_model_name([])
|
||||
|
||||
assert result == []
|
||||
|
||||
def test_process_model_name_no_models_prefix(self):
|
||||
"""Test process_model_name with model names that don't have 'models/' prefix"""
|
||||
gemini_model_info = GeminiModelInfo()
|
||||
|
||||
models = [
|
||||
{"name": "gemini-1.5-flash"}, # No "models/" prefix
|
||||
{"name": "custom-model"},
|
||||
]
|
||||
|
||||
result = gemini_model_info.process_model_name(models)
|
||||
|
||||
expected = [
|
||||
"gemini/gemini-1.5-flash",
|
||||
"gemini/custom-model",
|
||||
]
|
||||
|
||||
assert result == expected
|
||||
|
||||
|
||||
class TestGoogleAIStudioTokenCounter:
|
||||
"""Test suite for GoogleAIStudioTokenCounter class"""
|
||||
|
||||
def test_should_use_token_counting_api(self):
|
||||
"""Test should_use_token_counting_api method with different provider values"""
|
||||
from litellm.types.utils import LlmProviders
|
||||
|
||||
token_counter = GoogleAIStudioTokenCounter()
|
||||
|
||||
# Test with gemini provider - should return True
|
||||
assert token_counter.should_use_token_counting_api(LlmProviders.GEMINI.value) is True
|
||||
|
||||
# Test with other providers - should return False
|
||||
assert token_counter.should_use_token_counting_api(LlmProviders.OPENAI.value) is False
|
||||
assert token_counter.should_use_token_counting_api("anthropic") is False
|
||||
assert token_counter.should_use_token_counting_api("vertex_ai") is False
|
||||
|
||||
# Test with None - should return False
|
||||
assert token_counter.should_use_token_counting_api(None) is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_count_tokens(self):
|
||||
"""Test count_tokens method with mocked API response"""
|
||||
from litellm.types.utils import TokenCountResponse
|
||||
|
||||
token_counter = GoogleAIStudioTokenCounter()
|
||||
|
||||
# Mock the GoogleAIStudioTokenCounter from handler module
|
||||
mock_response = {
|
||||
"totalTokens": 31,
|
||||
"totalBillableCharacters": 96,
|
||||
"promptTokensDetails": [
|
||||
{
|
||||
"modality": "TEXT",
|
||||
"tokenCount": 31
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
with patch('litellm.llms.gemini.count_tokens.handler.GoogleAIStudioTokenCounter.acount_tokens',
|
||||
new_callable=AsyncMock) as mock_acount_tokens:
|
||||
mock_acount_tokens.return_value = mock_response
|
||||
|
||||
# Test data
|
||||
model_to_use = "gemini-1.5-flash"
|
||||
contents = [{"parts": [{"text": "Hello world"}]}]
|
||||
request_model = "gemini/gemini-1.5-flash"
|
||||
|
||||
# Call the method
|
||||
result = await token_counter.count_tokens(
|
||||
model_to_use=model_to_use,
|
||||
messages=None,
|
||||
contents=contents,
|
||||
deployment=None,
|
||||
request_model=request_model
|
||||
)
|
||||
|
||||
# Verify the result
|
||||
assert result is not None
|
||||
assert isinstance(result, TokenCountResponse)
|
||||
assert result.total_tokens == 31
|
||||
assert result.request_model == request_model
|
||||
assert result.model_used == model_to_use
|
||||
assert result.original_response == mock_response
|
||||
|
||||
# Verify the mock was called correctly
|
||||
mock_acount_tokens.assert_called_once_with(
|
||||
model=model_to_use,
|
||||
contents=contents
|
||||
)
|
@@ -0,0 +1,221 @@
|
||||
"""
|
||||
Test Gemini TTS (Text-to-Speech) functionality
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../../../..")
|
||||
) # Adds the parent directory to the system path
|
||||
|
||||
import litellm
|
||||
from litellm.llms.gemini.chat.transformation import GoogleAIStudioGeminiConfig
|
||||
from litellm.utils import get_supported_openai_params
|
||||
|
||||
|
||||
class TestGeminiTTSTransformation:
|
||||
"""Test Gemini TTS transformation functionality"""
|
||||
|
||||
def test_gemini_tts_model_detection(self):
|
||||
"""Test that TTS models are correctly identified"""
|
||||
config = GoogleAIStudioGeminiConfig()
|
||||
|
||||
# Test TTS models
|
||||
assert config.is_model_gemini_audio_model("gemini-2.5-flash-preview-tts") == True
|
||||
assert config.is_model_gemini_audio_model("gemini-2.5-pro-preview-tts") == True
|
||||
|
||||
# Test non-TTS models
|
||||
assert config.is_model_gemini_audio_model("gemini-2.5-flash") == False
|
||||
assert config.is_model_gemini_audio_model("gemini-2.5-pro") == False
|
||||
assert config.is_model_gemini_audio_model("gpt-4o-audio-preview") == False
|
||||
|
||||
def test_gemini_tts_supported_params(self):
|
||||
"""Test that audio parameter is included for TTS models"""
|
||||
config = GoogleAIStudioGeminiConfig()
|
||||
|
||||
# Test TTS model
|
||||
params = config.get_supported_openai_params("gemini-2.5-flash-preview-tts")
|
||||
assert "audio" in params
|
||||
|
||||
# Test that other standard params are still included
|
||||
assert "temperature" in params
|
||||
assert "max_tokens" in params
|
||||
assert "modalities" in params
|
||||
|
||||
# Test non-TTS model
|
||||
params_non_tts = config.get_supported_openai_params("gemini-2.5-flash")
|
||||
assert "audio" not in params_non_tts
|
||||
|
||||
def test_gemini_tts_audio_parameter_mapping(self):
|
||||
"""Test audio parameter mapping for TTS models"""
|
||||
config = GoogleAIStudioGeminiConfig()
|
||||
|
||||
non_default_params = {
|
||||
"audio": {
|
||||
"voice": "Kore",
|
||||
"format": "pcm16"
|
||||
}
|
||||
}
|
||||
optional_params = {}
|
||||
|
||||
result = config.map_openai_params(
|
||||
non_default_params=non_default_params,
|
||||
optional_params=optional_params,
|
||||
model="gemini-2.5-flash-preview-tts",
|
||||
drop_params=False
|
||||
)
|
||||
|
||||
# Check speech config is created
|
||||
assert "speechConfig" in result
|
||||
assert "voiceConfig" in result["speechConfig"]
|
||||
assert "prebuiltVoiceConfig" in result["speechConfig"]["voiceConfig"]
|
||||
assert result["speechConfig"]["voiceConfig"]["prebuiltVoiceConfig"]["voiceName"] == "Kore"
|
||||
|
||||
# Check response modalities
|
||||
assert "responseModalities" in result
|
||||
assert "AUDIO" in result["responseModalities"]
|
||||
|
||||
def test_gemini_tts_audio_parameter_with_existing_modalities(self):
|
||||
"""Test audio parameter mapping when modalities already exist"""
|
||||
config = GoogleAIStudioGeminiConfig()
|
||||
|
||||
non_default_params = {
|
||||
"audio": {
|
||||
"voice": "Puck",
|
||||
"format": "pcm16"
|
||||
}
|
||||
}
|
||||
optional_params = {
|
||||
"responseModalities": ["TEXT"]
|
||||
}
|
||||
|
||||
result = config.map_openai_params(
|
||||
non_default_params=non_default_params,
|
||||
optional_params=optional_params,
|
||||
model="gemini-2.5-flash-preview-tts",
|
||||
drop_params=False
|
||||
)
|
||||
|
||||
# Check that AUDIO is added to existing modalities
|
||||
assert "responseModalities" in result
|
||||
assert "TEXT" in result["responseModalities"]
|
||||
assert "AUDIO" in result["responseModalities"]
|
||||
|
||||
def test_gemini_tts_no_audio_parameter(self):
|
||||
"""Test that non-audio parameters are handled normally"""
|
||||
config = GoogleAIStudioGeminiConfig()
|
||||
|
||||
non_default_params = {
|
||||
"temperature": 0.7,
|
||||
"max_tokens": 100
|
||||
}
|
||||
optional_params = {}
|
||||
|
||||
result = config.map_openai_params(
|
||||
non_default_params=non_default_params,
|
||||
optional_params=optional_params,
|
||||
model="gemini-2.5-flash-preview-tts",
|
||||
drop_params=False
|
||||
)
|
||||
|
||||
# Should not have speech config
|
||||
assert "speechConfig" not in result
|
||||
# Should not automatically add audio modalities
|
||||
assert "responseModalities" not in result
|
||||
|
||||
def test_gemini_tts_invalid_audio_parameter(self):
|
||||
"""Test handling of invalid audio parameter"""
|
||||
config = GoogleAIStudioGeminiConfig()
|
||||
|
||||
non_default_params = {
|
||||
"audio": "invalid_string" # Should be dict
|
||||
}
|
||||
optional_params = {}
|
||||
|
||||
result = config.map_openai_params(
|
||||
non_default_params=non_default_params,
|
||||
optional_params=optional_params,
|
||||
model="gemini-2.5-flash-preview-tts",
|
||||
drop_params=False
|
||||
)
|
||||
|
||||
# Should not create speech config for invalid audio param
|
||||
assert "speechConfig" not in result
|
||||
|
||||
def test_gemini_tts_empty_audio_parameter(self):
|
||||
"""Test handling of empty audio parameter"""
|
||||
config = GoogleAIStudioGeminiConfig()
|
||||
|
||||
non_default_params = {
|
||||
"audio": {}
|
||||
}
|
||||
optional_params = {}
|
||||
|
||||
result = config.map_openai_params(
|
||||
non_default_params=non_default_params,
|
||||
optional_params=optional_params,
|
||||
model="gemini-2.5-flash-preview-tts",
|
||||
drop_params=False
|
||||
)
|
||||
|
||||
# Should still set response modalities even with empty audio config
|
||||
assert "responseModalities" in result
|
||||
assert "AUDIO" in result["responseModalities"]
|
||||
|
||||
def test_gemini_tts_audio_format_validation(self):
|
||||
"""Test audio format validation for TTS models"""
|
||||
config = GoogleAIStudioGeminiConfig()
|
||||
|
||||
# Test invalid format
|
||||
non_default_params = {
|
||||
"audio": {
|
||||
"voice": "Kore",
|
||||
"format": "wav" # Invalid format
|
||||
}
|
||||
}
|
||||
optional_params = {}
|
||||
|
||||
with pytest.raises(ValueError, match="Unsupported audio format for Gemini TTS models"):
|
||||
config.map_openai_params(
|
||||
non_default_params=non_default_params,
|
||||
optional_params=optional_params,
|
||||
model="gemini-2.5-flash-preview-tts",
|
||||
drop_params=False
|
||||
)
|
||||
|
||||
def test_gemini_tts_utils_integration(self):
|
||||
"""Test integration with LiteLLM utils functions"""
|
||||
# Test that get_supported_openai_params works with TTS models
|
||||
params = get_supported_openai_params("gemini-2.5-flash-preview-tts", "gemini")
|
||||
assert "audio" in params
|
||||
|
||||
# Test non-TTS model
|
||||
params_non_tts = get_supported_openai_params("gemini-2.5-flash", "gemini")
|
||||
assert "audio" not in params_non_tts
|
||||
|
||||
|
||||
def test_gemini_tts_completion_mock():
|
||||
"""Test Gemini TTS completion with mocked response"""
|
||||
with patch('litellm.completion') as mock_completion:
|
||||
# Mock a successful TTS response
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = "Generated audio response"
|
||||
mock_completion.return_value = mock_response
|
||||
|
||||
# Test completion call with audio parameter
|
||||
response = litellm.completion(
|
||||
model="gemini-2.5-flash-preview-tts",
|
||||
messages=[{"role": "user", "content": "Say hello"}],
|
||||
audio={"voice": "Kore", "format": "pcm16"}
|
||||
)
|
||||
|
||||
assert response is not None
|
||||
assert response.choices[0].message.content is not None
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__])
|
Reference in New Issue
Block a user