import os import sys import pytest from litellm.utils import supports_url_context sys.path.insert( 0, os.path.abspath("../..") ) # Adds the parent directory to the system paths from base_llm_unit_tests import BaseLLMChatTest from litellm.llms.vertex_ai.context_caching.transformation import ( separate_cached_messages, transform_openai_messages_to_gemini_context_caching, ) import litellm from litellm import completion import json class TestGoogleAIStudioGemini(BaseLLMChatTest): def get_base_completion_call_args(self) -> dict: return {"model": "gemini/gemini-2.0-flash"} def get_base_completion_call_args_with_reasoning_model(self) -> dict: return {"model": "gemini/gemini-2.5-flash"} def test_tool_call_no_arguments(self, tool_call_no_arguments): """Test that tool calls with no arguments is translated correctly. Relevant issue: https://github.com/BerriAI/litellm/issues/6833""" from litellm.litellm_core_utils.prompt_templates.factory import ( convert_to_gemini_tool_call_invoke, ) result = convert_to_gemini_tool_call_invoke(tool_call_no_arguments) print(result) @pytest.mark.flaky(retries=3, delay=2) def test_url_context(self): from litellm.utils import supports_url_context os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" litellm.model_cost = litellm.get_model_cost_map(url="") litellm._turn_on_debug() base_completion_call_args = self.get_base_completion_call_args() if not supports_url_context(base_completion_call_args["model"], None): pytest.skip("Model does not support url context") response = self.completion_function( **base_completion_call_args, messages=[ { "role": "user", "content": "Summarize the content of this URL: https://en.wikipedia.org/wiki/Artificial_intelligence", } ], tools=[{"urlContext": {}}], ) assert response is not None assert ( response.model_extra["vertex_ai_url_context_metadata"] is not None ), "URL context metadata should be present" print(f"response={response}") def test_gemini_context_caching_with_ttl(): """Test Gemini context caching with TTL support""" # Test case 1: Basic TTL functionality messages_with_ttl = [ { "role": "system", "content": [ { "type": "text", "text": "Here is the full text of a complex legal agreement" * 400, "cache_control": {"type": "ephemeral", "ttl": "3600s"}, } ], }, { "role": "user", "content": [ { "type": "text", "text": "What are the key terms and conditions in this agreement?", "cache_control": {"type": "ephemeral", "ttl": "7200s"}, } ], }, ] # Test the transformation function directly result = transform_openai_messages_to_gemini_context_caching( model="gemini-1.5-pro", messages=messages_with_ttl, cache_key="test-ttl-cache-key", ) # Verify TTL is properly included in the result assert "ttl" in result assert result["ttl"] == "3600s" # Should use the first valid TTL found assert result["model"] == "models/gemini-1.5-pro" assert result["displayName"] == "test-ttl-cache-key" # Test case 2: Invalid TTL should be ignored messages_invalid_ttl = [ { "role": "user", "content": [ { "type": "text", "text": "Cached content with invalid TTL", "cache_control": {"type": "ephemeral", "ttl": "invalid_ttl"}, } ], } ] result_invalid = transform_openai_messages_to_gemini_context_caching( model="gemini-1.5-pro", messages=messages_invalid_ttl, cache_key="test-invalid-ttl", ) # Verify invalid TTL is not included assert "ttl" not in result_invalid assert result_invalid["model"] == "models/gemini-1.5-pro" assert result_invalid["displayName"] == "test-invalid-ttl" # Test case 3: Messages without TTL should work normally messages_no_ttl = [ { "role": "user", "content": [ { "type": "text", "text": "Cached content without TTL", "cache_control": {"type": "ephemeral"}, } ], } ] result_no_ttl = transform_openai_messages_to_gemini_context_caching( model="gemini-1.5-pro", messages=messages_no_ttl, cache_key="test-no-ttl" ) # Verify no TTL field is present when not specified assert "ttl" not in result_no_ttl assert result_no_ttl["model"] == "models/gemini-1.5-pro" assert result_no_ttl["displayName"] == "test-no-ttl" # Test case 4: Mixed messages with some having TTL messages_mixed = [ { "role": "system", "content": [ { "type": "text", "text": "System message with TTL", "cache_control": {"type": "ephemeral", "ttl": "1800s"}, } ], }, { "role": "user", "content": [ { "type": "text", "text": "User message without TTL", "cache_control": {"type": "ephemeral"}, } ], }, {"role": "assistant", "content": "Assistant response without cache control"}, { "role": "user", "content": [ { "type": "text", "text": "Another user message", "cache_control": {"type": "ephemeral", "ttl": "900s"}, } ], }, ] # Test separation of cached messages cached_messages, non_cached_messages = separate_cached_messages(messages_mixed) assert len(cached_messages) > 0 assert len(non_cached_messages) > 0 # Test transformation with mixed messages result_mixed = transform_openai_messages_to_gemini_context_caching( model="gemini-1.5-pro", messages=messages_mixed, cache_key="test-mixed-ttl" ) # Should pick up the first valid TTL assert "ttl" in result_mixed assert result_mixed["ttl"] == "1800s" assert result_mixed["model"] == "models/gemini-1.5-pro" assert result_mixed["displayName"] == "test-mixed-ttl" def test_gemini_context_caching_separate_messages(): messages = [ # System Message { "role": "system", "content": [ { "type": "text", "text": "Here is the full text of a complex legal agreement" * 400, "cache_control": {"type": "ephemeral"}, } ], }, # marked for caching with the cache_control parameter, so that this checkpoint can read from the previous cache. { "role": "user", "content": [ { "type": "text", "text": "What are the key terms and conditions in this agreement?", "cache_control": {"type": "ephemeral"}, } ], }, { "role": "assistant", "content": "Certainly! the key terms and conditions are the following: the contract is 1 year long for $10/mo", }, # The final turn is marked with cache-control, for continuing in followups. { "role": "user", "content": [ { "type": "text", "text": "What are the key terms and conditions in this agreement?", "cache_control": {"type": "ephemeral"}, } ], }, ] cached_messages, non_cached_messages = separate_cached_messages(messages) print(cached_messages) print(non_cached_messages) assert len(cached_messages) > 0, "Cached messages should be present" assert len(non_cached_messages) > 0, "Non-cached messages should be present" def test_gemini_image_generation(): # litellm._turn_on_debug() response = completion( model="gemini/gemini-2.0-flash-exp-image-generation", messages=[{"role": "user", "content": "Generate an image of a cat"}], modalities=["image", "text"], ) assert response.choices[0].message.content is not None def test_gemini_thinking(): litellm._turn_on_debug() from litellm.types.utils import Message, CallTypes from litellm.utils import return_raw_request import json messages = [ { "role": "user", "content": "Explain the concept of Occam's Razor and provide a simple, everyday example", } ] reasoning_content = "I'm thinking about Occam's Razor." assistant_message = Message( content="Okay, let's break down Occam's Razor.", reasoning_content=reasoning_content, role="assistant", tool_calls=None, function_call=None, provider_specific_fields=None, ) messages.append(assistant_message) raw_request = return_raw_request( endpoint=CallTypes.completion, kwargs={ "model": "gemini/gemini-2.5-flash", "messages": messages, }, ) assert reasoning_content in json.dumps(raw_request) response = completion( model="gemini/gemini-2.5-flash", messages=messages, # make sure call works ) print(response.choices[0].message) assert response.choices[0].message.content is not None def test_gemini_thinking_budget_0(): litellm._turn_on_debug() from litellm.types.utils import Message, CallTypes from litellm.utils import return_raw_request import json raw_request = return_raw_request( endpoint=CallTypes.completion, kwargs={ "model": "gemini/gemini-2.5-flash", "messages": [ { "role": "user", "content": "Explain the concept of Occam's Razor and provide a simple, everyday example", } ], "thinking": {"type": "enabled", "budget_tokens": 0}, }, ) print(raw_request) assert "0" in json.dumps(raw_request["raw_request_body"]) def test_gemini_finish_reason(): import os from litellm import completion litellm._turn_on_debug() response = completion( model="gemini/gemini-1.5-pro", messages=[{"role": "user", "content": "give me 3 random words"}], max_tokens=2, ) print(response) assert response.choices[0].finish_reason is not None assert response.choices[0].finish_reason == "length" def test_gemini_url_context(): from litellm import completion litellm._turn_on_debug() url = "https://ai.google.dev/gemini-api/docs/models" prompt = f""" Summarize this document: {url} """ response = completion( model="gemini/gemini-2.5-flash", messages=[{"role": "user", "content": prompt}], tools=[{"urlContext": {}}], ) print(response) message = response.choices[0].message.content assert message is not None url_context_metadata = response.model_extra["vertex_ai_url_context_metadata"] assert url_context_metadata is not None urlMetadata = url_context_metadata[0]["urlMetadata"][0] assert urlMetadata["retrievedUrl"] == url assert urlMetadata["urlRetrievalStatus"] == "URL_RETRIEVAL_STATUS_SUCCESS" @pytest.mark.flaky(retries=3, delay=2) def test_gemini_with_grounding(): from litellm import completion, Usage, stream_chunk_builder litellm._turn_on_debug() litellm.set_verbose = True tools = [{"googleSearch": {}}] # response = completion(model="gemini/gemini-2.0-flash", messages=[{"role": "user", "content": "What is the capital of France?"}], tools=tools) # print(response) # usage: Usage = response.usage # assert usage.prompt_tokens_details.web_search_requests is not None # assert usage.prompt_tokens_details.web_search_requests > 0 ## Check streaming response = completion( model="gemini/gemini-2.0-flash", messages=[{"role": "user", "content": "What is the capital of France?"}], tools=tools, stream=True, stream_options={"include_usage": True}, ) chunks = [] for chunk in response: print(f"received chunk: {chunk}") chunks.append(chunk) print(f"chunks before stream_chunk_builder: {chunks}") assert len(chunks) > 0 complete_response = stream_chunk_builder(chunks) print(complete_response) assert complete_response is not None usage: Usage = complete_response.usage assert usage.prompt_tokens_details.web_search_requests is not None assert usage.prompt_tokens_details.web_search_requests > 0 def test_gemini_with_empty_function_call_arguments(): from litellm import completion litellm._turn_on_debug() tools = [ { "type": "function", "function": { "name": "get_current_weather", "parameters": "", }, } ] response = completion( model="gemini/gemini-2.0-flash", messages=[{"role": "user", "content": "What is the capital of France?"}], tools=tools, ) print(response) assert response.choices[0].message.content is not None @pytest.mark.asyncio async def test_claude_tool_use_with_gemini(): response = await litellm.anthropic.messages.acreate( messages=[ {"role": "user", "content": "Hello, can you tell me the weather in Boston. Please respond with a tool call?"} ], model="gemini/gemini-2.5-flash", stream=True, max_tokens=100, tools=[ { "name": "get_weather", "description": "Get current weather information for a specific location", "input_schema": { "type": "object", "properties": {"location": {"type": "string"}}, }, } ], ) is_content_block_tool_use = False is_partial_json = False has_usage_in_message_delta = False is_content_block_stop = False async for chunk in response: print(chunk) if "content_block_stop" in str(chunk): is_content_block_stop = True # Handle bytes chunks (SSE format) if isinstance(chunk, bytes): chunk_str = chunk.decode("utf-8") # Parse SSE format: event: \ndata: \n\n if "data: " in chunk_str: try: # Extract JSON from data line data_line = [ line for line in chunk_str.split("\n") if line.startswith("data: ") ][0] json_str = data_line[6:] # Remove 'data: ' prefix chunk_data = json.loads(json_str) # Check for tool_use if "tool_use" in json_str: is_content_block_tool_use = True if "partial_json" in json_str: is_partial_json = True if "content_block_stop" in json_str: is_content_block_stop = True # Check for usage in message_delta with stop_reason if ( chunk_data.get("type") == "message_delta" and chunk_data.get("delta", {}).get("stop_reason") is not None and "usage" in chunk_data ): has_usage_in_message_delta = True # Verify usage has the expected structure usage = chunk_data["usage"] assert ( "input_tokens" in usage ), "input_tokens should be present in usage" assert ( "output_tokens" in usage ), "output_tokens should be present in usage" assert isinstance( usage["input_tokens"], int ), "input_tokens should be an integer" assert isinstance( usage["output_tokens"], int ), "output_tokens should be an integer" print(f"Found usage in message_delta: {usage}") except (json.JSONDecodeError, IndexError) as e: # Skip chunks that aren't valid JSON pass else: # Handle dict chunks (fallback) if "tool_use" in str(chunk): is_content_block_tool_use = True if "partial_json" in str(chunk): is_partial_json = True if "content_block_stop" in str(chunk): is_content_block_stop = True assert is_content_block_tool_use, "content_block_tool_use should be present" assert is_partial_json, "partial_json should be present" assert ( has_usage_in_message_delta ), "Usage should be present in message_delta with stop_reason" assert is_content_block_stop, "is_content_block_stop should be present" def test_gemini_tool_use(): data = { "max_tokens": 8192, "stream": True, "temperature": 0.3, "messages": [ {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "What's the weather like in Lima, Peru today?"}, ], "model": "gemini/gemini-2.0-flash", "tools": [ { "type": "function", "function": { "name": "get_weather", "description": "Retrieve current weather for a specific location", "parameters": { "type": "object", "properties": { "location": { "type": "string", "description": "City and country, e.g., Lima, Peru", }, "unit": { "type": "string", "enum": ["celsius", "fahrenheit"], "description": "Temperature unit", }, }, "required": ["location"], }, }, } ], "stream_options": {"include_usage": True}, } response = litellm.completion(**data) print(response) stop_reason = None for chunk in response: print(chunk) if chunk.choices[0].finish_reason: stop_reason = chunk.choices[0].finish_reason assert stop_reason is not None assert stop_reason == "tool_calls"