Added LiteLLM to the stack

commit d220b04e32 (parent 0648c1968c), authored 2025-08-18 09:40:50 +00:00
2682 changed files with 533609 additions and 1 deletion

@@ -0,0 +1,286 @@
"""
Test reasoning content preservation in Responses API transformation
"""
from unittest.mock import AsyncMock
from litellm.types.utils import ModelResponseStream, StreamingChoices, Delta
from litellm.responses.litellm_completion_transformation.streaming_iterator import (
LiteLLMCompletionStreamingIterator,
)
from litellm.responses.litellm_completion_transformation.transformation import (
LiteLLMCompletionResponsesConfig,
)
from litellm.types.utils import ModelResponse, Choices, Message


class TestReasoningContentStreaming:
"""Test reasoning content preservation during streaming"""
def test_reasoning_content_in_delta(self):
"""Test that reasoning content is preserved in streaming deltas"""
# Setup
chunk = ModelResponseStream(
id="test-id",
created=1234567890,
model="test-model",
object="chat.completion.chunk",
choices=[
StreamingChoices(
finish_reason=None,
index=0,
delta=Delta(
content="",
role="assistant",
reasoning_content="Let me think about this problem...",
),
)
],
)
mock_stream = AsyncMock()
iterator = LiteLLMCompletionStreamingIterator(
litellm_custom_stream_wrapper=mock_stream,
request_input="Test input",
responses_api_request={},
)
# Execute
transformed_chunk = (
iterator._transform_chat_completion_chunk_to_response_api_chunk(chunk)
)
# Assert
assert transformed_chunk.delta == "Let me think about this problem..."
assert transformed_chunk.type == "response.reasoning_summary_text.delta"
def test_mixed_content_and_reasoning(self):
"""Test handling of both content and reasoning content"""
# Setup
chunk = ModelResponseStream(
id="test-id",
created=1234567890,
model="test-model",
object="chat.completion.chunk",
choices=[
StreamingChoices(
finish_reason=None,
index=0,
delta=Delta(
content="Here is the answer",
role="assistant",
reasoning_content="First, let me analyze...",
),
)
],
)
mock_stream = AsyncMock()
iterator = LiteLLMCompletionStreamingIterator(
litellm_custom_stream_wrapper=mock_stream,
request_input="Test input",
responses_api_request={},
)
# Execute
transformed_chunk = (
iterator._transform_chat_completion_chunk_to_response_api_chunk(chunk)
)
# Assert
assert transformed_chunk.delta == "First, let me analyze..."
assert transformed_chunk.type == "response.reasoning_summary_text.delta"
def test_no_reasoning_content(self):
"""Test handling when no reasoning content is present"""
# Setup
chunk = ModelResponseStream(
id="test-id",
created=1234567890,
model="test-model",
object="chat.completion.chunk",
choices=[
StreamingChoices(
finish_reason=None,
index=0,
delta=Delta(
content="Regular content only",
role="assistant",
),
)
],
)
mock_stream = AsyncMock()
iterator = LiteLLMCompletionStreamingIterator(
litellm_custom_stream_wrapper=mock_stream,
request_input="Test input",
responses_api_request={},
)
# Execute
transformed_chunk = (
iterator._transform_chat_completion_chunk_to_response_api_chunk(chunk)
)
# Assert
assert transformed_chunk.delta == "Regular content only"
assert transformed_chunk.type == "response.output_text.delta"


class TestReasoningContentFinalResponse:
"""Test reasoning content preservation in final response transformation"""
def test_reasoning_content_in_final_response(self):
"""Test that reasoning content is included in final response"""
# Setup
response = ModelResponse(
id="test-id",
created=1234567890,
model="test-model",
object="chat.completion",
choices=[
Choices(
finish_reason="stop",
index=0,
message=Message(
content="Here is my answer",
role="assistant",
reasoning_content="Let me think step by step about this problem...",
),
)
],
)
# Execute
responses_api_response = LiteLLMCompletionResponsesConfig.transform_chat_completion_response_to_responses_api_response(
request_input="Test input",
responses_api_request={},
chat_completion_response=response,
)
# Assert
assert hasattr(responses_api_response, "output")
assert len(responses_api_response.output) > 0
reasoning_items = [
item for item in responses_api_response.output if item.type == "reasoning"
]
assert len(reasoning_items) > 0, "No reasoning item found in output"
reasoning_item = reasoning_items[0]
assert (
reasoning_item.content[0].text
== "Let me think step by step about this problem..."
)
def test_no_reasoning_content_in_response(self):
"""Test handling when no reasoning content in response"""
# Setup
response = ModelResponse(
id="test-id",
created=1234567890,
model="test-model",
object="chat.completion",
choices=[
Choices(
finish_reason="stop",
index=0,
message=Message(
content="Simple answer",
role="assistant",
),
)
],
)
# Execute
responses_api_response = LiteLLMCompletionResponsesConfig.transform_chat_completion_response_to_responses_api_response(
request_input="Test input",
responses_api_request={},
chat_completion_response=response,
)
# Assert
reasoning_items = [
item for item in responses_api_response.output if item.type == "reasoning"
]
assert (
len(reasoning_items) == 0
), "Should have no reasoning items when no reasoning content present"
def test_multiple_choices_with_reasoning(self):
"""Test handling multiple choices, first with reasoning content"""
# Setup
response = ModelResponse(
id="test-id",
created=1234567890,
model="test-model",
object="chat.completion",
choices=[
Choices(
finish_reason="stop",
index=0,
message=Message(
content="First answer",
role="assistant",
reasoning_content="Reasoning for first answer",
),
),
Choices(
finish_reason="stop",
index=1,
message=Message(
content="Second answer",
role="assistant",
reasoning_content="Reasoning for second answer",
),
),
],
)
# Execute
responses_api_response = LiteLLMCompletionResponsesConfig.transform_chat_completion_response_to_responses_api_response(
request_input="Test input",
responses_api_request={},
chat_completion_response=response,
)
# Assert
reasoning_items = [
item for item in responses_api_response.output if item.type == "reasoning"
]
assert len(reasoning_items) == 1, "Should have exactly one reasoning item"
assert reasoning_items[0].content[0].text == "Reasoning for first answer"


def test_streaming_chunk_id_raw():
"""Test that streaming chunk IDs are raw (not encoded) to match OpenAI format"""
chunk = ModelResponseStream(
id="chunk-123",
created=1234567890,
model="test-model",
object="chat.completion.chunk",
choices=[
StreamingChoices(
finish_reason=None,
index=0,
delta=Delta(content="Hello", role="assistant"),
)
],
)
iterator = LiteLLMCompletionStreamingIterator(
litellm_custom_stream_wrapper=AsyncMock(),
request_input="Test input",
responses_api_request={},
custom_llm_provider="openai",
litellm_metadata={"model_info": {"id": "gpt-4"}},
)
result = iterator._transform_chat_completion_chunk_to_response_api_chunk(chunk)
# Streaming chunk IDs should be raw (like OpenAI's msg_xxx format)
assert result.item_id == "chunk-123" # Should be raw, not encoded
assert not result.item_id.startswith("resp_") # Should NOT have resp_ prefix
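
Taken together, these streaming tests pin down the contract a downstream consumer relies on: reasoning deltas arrive as "response.reasoning_summary_text.delta" events and regular text as "response.output_text.delta". Below is a minimal consumer sketch, assuming only that the transformed chunks expose the `type` and `delta` attributes asserted above and that the iterator is asynchronous; the helper name and return shape are illustrative, not part of LiteLLM.

async def collect_stream_text(iterator) -> dict:
    """Split streamed Responses API deltas into reasoning text and output text."""
    reasoning_parts, output_parts = [], []
    async for chunk in iterator:
        if chunk is None:
            # Assume the transformer may drop chunks it cannot map; skip them.
            continue
        if chunk.type == "response.reasoning_summary_text.delta":
            reasoning_parts.append(chunk.delta)
        elif chunk.type == "response.output_text.delta":
            output_parts.append(chunk.delta)
    return {
        "reasoning": "".join(reasoning_parts),
        "output": "".join(output_parts),
    }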

@@ -0,0 +1,366 @@
import json
import os
import sys
from unittest.mock import AsyncMock, patch
import pytest
from fastapi import HTTPException
from fastapi.testclient import TestClient
sys.path.insert(
0, os.path.abspath("../../..")
) # Adds the parent directory to the system path
from litellm.responses.litellm_completion_transformation import session_handler
from litellm.responses.litellm_completion_transformation.session_handler import (
ResponsesSessionHandler,
)


@pytest.mark.asyncio
async def test_get_chat_completion_message_history_for_previous_response_id():
"""
Test get_chat_completion_message_history_for_previous_response_id with mock data
"""
# Mock data based on the provided spend logs (simplified version)
mock_spend_logs = [
{
"request_id": "chatcmpl-935b8dad-fdc2-466e-a8ca-e26e5a8a21bb",
"call_type": "aresponses",
"api_key": "88dc28d0f030c55ed4ab77ed8faf098196cb1c05df778539800c9f1243fe6b4b",
"spend": 0.004803,
"total_tokens": 329,
"prompt_tokens": 11,
"completion_tokens": 318,
"startTime": "2025-05-30T03:17:06.703+00:00",
"endTime": "2025-05-30T03:17:11.894+00:00",
"model": "claude-3-5-sonnet-latest",
"session_id": "a96757c4-c6dc-4c76-b37e-e7dfa526b701",
"proxy_server_request": {
"input": "who is Michael Jordan",
"model": "anthropic/claude-3-5-sonnet-latest",
},
"response": {
"id": "chatcmpl-935b8dad-fdc2-466e-a8ca-e26e5a8a21bb",
"model": "claude-3-5-sonnet-20241022",
"object": "chat.completion",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "Michael Jordan (born February 17, 1963) is widely considered the greatest basketball player of all time. Here are some key points about him...",
"tool_calls": None,
"function_call": None,
},
"finish_reason": "stop",
}
],
"created": 1748575031,
"usage": {
"total_tokens": 329,
"prompt_tokens": 11,
"completion_tokens": 318,
},
},
"status": "success",
},
{
"request_id": "chatcmpl-370760c9-39fa-4db7-b034-d1f8d933c935",
"call_type": "aresponses",
"api_key": "88dc28d0f030c55ed4ab77ed8faf098196cb1c05df778539800c9f1243fe6b4b",
"spend": 0.010437,
"total_tokens": 967,
"prompt_tokens": 339,
"completion_tokens": 628,
"startTime": "2025-05-30T03:17:28.600+00:00",
"endTime": "2025-05-30T03:17:39.921+00:00",
"model": "claude-3-5-sonnet-latest",
"session_id": "a96757c4-c6dc-4c76-b37e-e7dfa526b701",
"proxy_server_request": {
"input": "can you tell me more about him",
"model": "anthropic/claude-3-5-sonnet-latest",
"previous_response_id": "resp_bGl0ZWxsbTpjdXN0b21fbGxtX3Byb3ZpZGVyOmFudGhyb3BpYzttb2RlbF9pZDplMGYzMDJhMTQxMmU3ODQ3MGViYjI4Y2JlZDAxZmZmNWY4OGMwZDMzMWM2NjdlOWYyYmE0YjQxM2M2ZmJkMjgyO3Jlc3BvbnNlX2lkOmNoYXRjbXBsLTkzNWI4ZGFkLWZkYzItNDY2ZS1hOGNhLWUyNmU1YThhMjFiYg==",
},
"response": {
"id": "chatcmpl-370760c9-39fa-4db7-b034-d1f8d933c935",
"model": "claude-3-5-sonnet-20241022",
"object": "chat.completion",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "Here's more detailed information about Michael Jordan...",
"tool_calls": None,
"function_call": None,
},
"finish_reason": "stop",
}
],
"created": 1748575059,
"usage": {
"total_tokens": 967,
"prompt_tokens": 339,
"completion_tokens": 628,
},
},
"status": "success",
},
]
# Mock the get_all_spend_logs_for_previous_response_id method
with patch.object(
ResponsesSessionHandler,
"get_all_spend_logs_for_previous_response_id",
new_callable=AsyncMock,
) as mock_get_spend_logs:
mock_get_spend_logs.return_value = mock_spend_logs
# Test the function
previous_response_id = "chatcmpl-935b8dad-fdc2-466e-a8ca-e26e5a8a21bb"
result = await ResponsesSessionHandler.get_chat_completion_message_history_for_previous_response_id(
previous_response_id
)
# Verify the mock was called with correct parameters
mock_get_spend_logs.assert_called_once_with(previous_response_id)
# Verify the returned ChatCompletionSession structure
assert "messages" in result
assert "litellm_session_id" in result
# Verify session_id is extracted correctly
assert result["litellm_session_id"] == "a96757c4-c6dc-4c76-b37e-e7dfa526b701"
# Verify messages structure
messages = result["messages"]
assert len(messages) == 4 # 2 user messages + 2 assistant messages
# Check the message sequence
# First user message
assert messages[0].get("role") == "user"
assert messages[0].get("content") == "who is Michael Jordan"
# First assistant response
assert messages[1].get("role") == "assistant"
content_1 = messages[1].get("content", "")
if isinstance(content_1, str):
assert "Michael Jordan" in content_1
assert content_1.startswith("Michael Jordan (born February 17, 1963)")
# Second user message
assert messages[2].get("role") == "user"
assert messages[2].get("content") == "can you tell me more about him"
# Second assistant response
assert messages[3].get("role") == "assistant"
content_3 = messages[3].get("content", "")
if isinstance(content_3, str):
assert "Here's more detailed information about Michael Jordan" in content_3


@pytest.mark.asyncio
async def test_get_chat_completion_message_history_empty_spend_logs():
"""
Test get_chat_completion_message_history_for_previous_response_id with empty spend logs
"""
with patch.object(
ResponsesSessionHandler,
"get_all_spend_logs_for_previous_response_id",
new_callable=AsyncMock,
) as mock_get_spend_logs:
mock_get_spend_logs.return_value = []
previous_response_id = "non-existent-id"
result = await ResponsesSessionHandler.get_chat_completion_message_history_for_previous_response_id(
previous_response_id
)
# Verify empty result structure
assert result.get("messages") == []
assert result.get("litellm_session_id") is None


@pytest.mark.asyncio
async def test_e2e_cold_storage_successful_retrieval():
"""
Test end-to-end cold storage functionality with successful retrieval of full proxy request from cold storage.
"""
# Mock spend logs with cold storage object key in metadata
mock_spend_logs = [
{
"request_id": "chatcmpl-test-123",
"session_id": "session-456",
"metadata": '{"cold_storage_object_key": "s3://test-bucket/requests/session_456_req1.json"}',
"proxy_server_request": '{"litellm_truncated": true}', # Truncated payload
"response": {
"id": "chatcmpl-test-123",
"object": "chat.completion",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "I am an AI assistant."
}
}
]
}
}
]
# Full proxy request data from cold storage
full_proxy_request = {
"input": "Hello, who are you?",
"model": "gpt-4",
"messages": [{"role": "user", "content": "Hello, who are you?"}]
}
with patch.object(
ResponsesSessionHandler,
"get_all_spend_logs_for_previous_response_id",
new_callable=AsyncMock,
) as mock_get_spend_logs, \
patch.object(session_handler, "COLD_STORAGE_HANDLER") as mock_cold_storage, \
patch("litellm.proxy.spend_tracking.cold_storage_handler.ColdStorageHandler._get_configured_cold_storage_custom_logger", return_value="s3"):
# Setup mocks
mock_get_spend_logs.return_value = mock_spend_logs
mock_cold_storage.get_proxy_server_request_from_cold_storage_with_object_key = AsyncMock(return_value=full_proxy_request)
# Call the main function
result = await ResponsesSessionHandler.get_chat_completion_message_history_for_previous_response_id(
"chatcmpl-test-123"
)
# Verify cold storage was called with correct object key
mock_cold_storage.get_proxy_server_request_from_cold_storage_with_object_key.assert_called_once_with(
object_key="s3://test-bucket/requests/session_456_req1.json"
)
# Verify result structure
assert result.get("litellm_session_id") == "session-456"
assert len(result.get("messages", [])) >= 1 # At least the assistant response


@pytest.mark.asyncio
async def test_e2e_cold_storage_fallback_to_truncated_payload():
"""
Test end-to-end cold storage functionality when object key is missing, falling back to truncated payload.
"""
# Mock spend logs without cold storage object key
mock_spend_logs = [
{
"request_id": "chatcmpl-test-789",
"session_id": "session-999",
"metadata": '{"user_api_key": "test-key"}', # No cold storage object key
"proxy_server_request": '{"input": "Truncated message", "model": "gpt-4"}', # Regular payload
"response": {
"id": "chatcmpl-test-789",
"object": "chat.completion",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "This is a response."
}
}
]
}
}
]
with patch.object(
ResponsesSessionHandler,
"get_all_spend_logs_for_previous_response_id",
new_callable=AsyncMock,
) as mock_get_spend_logs, \
patch.object(session_handler, "COLD_STORAGE_HANDLER") as mock_cold_storage:
# Setup mocks
mock_get_spend_logs.return_value = mock_spend_logs
# Call the main function
result = await ResponsesSessionHandler.get_chat_completion_message_history_for_previous_response_id(
"chatcmpl-test-789"
)
# Verify cold storage was NOT called since no object key in metadata
mock_cold_storage.get_proxy_server_request_from_cold_storage_with_object_key.assert_not_called()
# Verify result structure
assert result.get("litellm_session_id") == "session-999"
assert len(result.get("messages", [])) >= 1 # At least the assistant response


@pytest.mark.asyncio
async def test_should_check_cold_storage_for_full_payload():
"""
Test _should_check_cold_storage_for_full_payload returns True for proxy server requests with truncated content
"""
# Test case 1: Proxy server request with truncated PDF content (should return True)
proxy_request_with_truncated_pdf = {
"input": [
{
"role": "user",
"type": "message",
"content": [
{
"text": "what was datadogs largest source of operating cash ? quote the section you saw ",
"type": "input_text"
},
{
"type": "input_image",
"image_url": "data:application/pdf;base64,JVBERi0xLjcKJYGBgYEKCjcgMCBvYmoKPDwKL0ZpbHRlciAvRmxhdGVEZWNvZGUKL0xlbmd0aCA1NjcxCj4+CnN0cmVhbQp4nO1dW4/cthV+31+h5wKVeb8AhoG9Bn0I0DYL9NlInQBFHKSpA+Tnl5qRNNRIn8ij4WpnbdqAsRaX90Oe23cOWyH94U/Dwt+/ttF/neKt59675sfPN/+9Ubp1MvwRjfAtN92fRkgn2+5jI5Xyre9++fdPN//6S/NrqCFax4XqvnVtn/631FLogjfd339+1xx/+P3nm3ffyebn/92ww2Bc46zRrGv/p5vWMOmb+N9Qb/4xtOEazn2oH3rjfV3fDTj+t6s7+zjU5XFdFxo9fPs8/CiaX26cYmc/svDjhlF+Pv7QNdT30/9wbI8dFjK0cfzhUO8wPjaOr/Eq/v/d8827vzfv37/7/v5vD6HKhw93D/c3755UI3jYuOb5p7Dsh53nYQtZqyUXugn71Dx/vnnPmHQfmuf/3HDdKhY2z8jwq8//broSjkrE/aHEtZIxZhQ/VbHHKqoVQhsv7KmKgyUWdcPkoUSHaYgwmmhk5liFt85w45WcDUC0RgntpTp1o44lMhCpl95eNOZ+AI/f3988Pp9tAV/dAu5VKz0Ls+SB0vstgNNZWTUDtw1vqC65oahKctUW5gl3etg1EnXSCWplzHewG7gDoq9jWu28DYuzPB3NucmZzm3UmvFcLI/aMpcxT0w1AvUi2aFEsJZLprxMF0xIQzmXQQArRwA2Fo/YCm7P13LxdIrodIa7QIojL5zekqblloeuwkiGI3o7jk+zMEJ9vqW+NWFDtTDnU7KtCpet1WZGHqyVxjLNxPlcdcudsU648xnNOxmIY95Wf3JduAidF3q+bgtVUC/s6VAgZ/TUE9q8oD+DCwM2qA/UFD/eWqY1RrFQ61QgUYFDBRYV3GOKkSPFaEQxQqqWWRcGzZ3u... (litellm_truncated 1197576 chars)"
}
]
}
],
"model": "anthropic/claude-3-7-sonnet-20250219",
"stream": True,
"litellm_trace_id": "16b86861-c120-4ecb-865b-4d2238bfd8f0"
}
# Test case 2: Regular proxy request without truncation (should return False)
proxy_request_regular = {
"input": [
{
"role": "user",
"type": "message",
"content": "Hello, this is a regular message"
}
],
"model": "anthropic/claude-3-7-sonnet-20250219",
"stream": True
}
# Test case 3: Empty request (should return True)
proxy_request_empty = {}
# Test case 4: None request (should return True)
proxy_request_none = None
with patch("litellm.proxy.spend_tracking.cold_storage_handler.ColdStorageHandler._get_configured_cold_storage_custom_logger", return_value="s3"):
# Test case 1: Should return True for truncated content
result1 = ResponsesSessionHandler._should_check_cold_storage_for_full_payload(proxy_request_with_truncated_pdf)
assert result1 == True, "Should return True for proxy request with truncated PDF content"
# Test case 2: Should return False for regular content
result2 = ResponsesSessionHandler._should_check_cold_storage_for_full_payload(proxy_request_regular)
assert result2 == False, "Should return False for regular proxy request without truncation"
# Test case 3: Should return True for empty request
result3 = ResponsesSessionHandler._should_check_cold_storage_for_full_payload(proxy_request_empty)
assert result3 == True, "Should return True for empty proxy request"
# Test case 4: Should return True for None request
result4 = ResponsesSessionHandler._should_check_cold_storage_for_full_payload(proxy_request_none)
assert result4 == True, "Should return True for None proxy request"
# Test case 5: Should return False when cold storage is not configured
with patch("litellm.proxy.spend_tracking.cold_storage_handler.ColdStorageHandler._get_configured_cold_storage_custom_logger", return_value=None):
result5 = ResponsesSessionHandler._should_check_cold_storage_for_full_payload(proxy_request_with_truncated_pdf)
assert result5 == False, "Should return False when cold storage is not configured, even with truncated content"
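
One detail worth noting in the second fixture above: "previous_response_id" is not a raw chat-completion ID but an encoded one, a "resp_" prefix followed by base64 that appears to decode to "litellm:custom_llm_provider:<provider>;model_id:<hash>;response_id:<chat completion id>". The sketch below unpacks an ID of that assumed shape for illustration only; the production path goes through ResponsesAPIRequestUtils (exercised in a later file), not a hand-rolled parser like this.

import base64


def sketch_decode_encoded_response_id(encoded_id: str) -> dict:
    """Best-effort decode of a LiteLLM-style 'resp_<base64>' response ID (sketch)."""
    if not encoded_id.startswith("resp_"):
        return {"response_id": encoded_id}
    try:
        decoded = base64.b64decode(encoded_id[len("resp_"):]).decode("utf-8")
    except Exception:
        # Not actually base64 (e.g. a provider's own "resp_..." ID): pass through.
        return {"response_id": encoded_id}
    if not decoded.startswith("litellm:"):
        return {"response_id": encoded_id}
    # Assumed layout after the "litellm:" marker: ";"-separated "key:value" pairs.
    fields = decoded[len("litellm:"):].split(";")
    return dict(field.split(":", 1) for field in fields)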

@@ -0,0 +1,186 @@
"""
Unit tests for cold storage object key integration.
Tests for the changes to integrate cold storage handling across different components:
1. Add cold_storage_object_key field to StandardLoggingMetadata and SpendLogsMetadata
2. S3Logger generates object key when cold storage is enabled
3. Store object key in SpendLogsMetadata via spend_tracking_utils
4. Session handler uses object key from spend logs metadata
5. S3Logger supports retrieval using provided object key
"""
import json
from datetime import datetime, timezone
from typing import Optional
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from litellm.integrations.s3_v2 import S3Logger
from litellm.proxy._types import SpendLogsMetadata, SpendLogsPayload
from litellm.proxy.spend_tracking.cold_storage_handler import ColdStorageHandler
from litellm.proxy.spend_tracking.spend_tracking_utils import _get_spend_logs_metadata
from litellm.responses.litellm_completion_transformation.session_handler import (
ResponsesSessionHandler,
)
from litellm.types.utils import StandardLoggingMetadata, StandardLoggingPayload


class TestColdStorageObjectKeyIntegration:
"""Test suite for cold storage object key integration."""
def test_standard_logging_metadata_has_cold_storage_object_key_field(self):
"""
Test: Add cold_storage_object_key field to StandardLoggingMetadata.
This test verifies that the StandardLoggingMetadata TypedDict has the
cold_storage_object_key field for storing S3/GCS object keys.
"""
from litellm.types.utils import StandardLoggingMetadata
# Create a StandardLoggingMetadata instance with cold_storage_object_key
metadata = StandardLoggingMetadata(
user_api_key_hash="test_hash",
cold_storage_object_key="test/path/to/object.json"
)
# Verify the field can be set and accessed
assert metadata.get("cold_storage_object_key") == "test/path/to/object.json"
assert "cold_storage_object_key" in StandardLoggingMetadata.__annotations__
def test_spend_logs_metadata_has_cold_storage_object_key_field(self):
"""
Test: Add cold_storage_object_key field to SpendLogsMetadata.
This test verifies that the SpendLogsMetadata TypedDict has the
cold_storage_object_key field for storing S3/GCS object keys.
"""
# Create a SpendLogsMetadata instance with cold_storage_object_key
metadata = SpendLogsMetadata(
user_api_key="test_key",
cold_storage_object_key="test/path/to/object.json"
)
# Verify the field can be set and accessed
assert metadata.get("cold_storage_object_key") == "test/path/to/object.json"
# Verify it's part of the SpendLogsMetadata annotations
assert "cold_storage_object_key" in SpendLogsMetadata.__annotations__
def test_spend_tracking_utils_stores_object_key_in_metadata(self):
"""
Test: Store object key in SpendLogsMetadata via spend_tracking_utils.
This test verifies that the _get_spend_logs_metadata function extracts
the cold_storage_object_key from StandardLoggingPayload and stores it
in SpendLogsMetadata.
"""
# Create test data
metadata = {
"user_api_key": "test_key",
"user_api_key_team_id": "test_team"
}
# Call the function
result = _get_spend_logs_metadata(
metadata=metadata,
cold_storage_object_key="test/path/to/object.json"
)
# Verify the object key is stored in the result
assert result.get("cold_storage_object_key") == "test/path/to/object.json"
def test_session_handler_extracts_object_key_from_spend_log(self):
"""
Test: Session handler extracts object key from spend logs metadata.
This test verifies that the ResponsesSessionHandler can extract the
cold_storage_object_key from spend log metadata.
"""
# Create test spend log
spend_log = {
"request_id": "test_request_id",
"metadata": json.dumps({
"cold_storage_object_key": "test/path/to/object.json",
"user_api_key": "test_key"
})
}
# Test the extraction method
object_key = ResponsesSessionHandler._get_cold_storage_object_key_from_spend_log(spend_log)
assert object_key == "test/path/to/object.json"
def test_session_handler_handles_dict_metadata_in_spend_log(self):
"""
Test: Session handler handles dict metadata in spend log.
This test verifies that the method works when metadata is already a dict.
"""
# Create test spend log with dict metadata
spend_log = {
"request_id": "test_request_id",
"metadata": {
"cold_storage_object_key": "test/path/to/object.json",
"user_api_key": "test_key"
}
}
# Test the extraction method
object_key = ResponsesSessionHandler._get_cold_storage_object_key_from_spend_log(spend_log)
assert object_key == "test/path/to/object.json"
@pytest.mark.asyncio
async def test_cold_storage_handler_supports_object_key_retrieval(self):
"""
Test: ColdStorageHandler supports object key retrieval.
This test verifies that the ColdStorageHandler has the new method
for retrieving objects using object keys directly.
"""
handler = ColdStorageHandler()
# Mock the custom logger
mock_logger = AsyncMock()
mock_logger.get_proxy_server_request_from_cold_storage_with_object_key = AsyncMock(
return_value={"test": "data"}
)
with patch.object(handler, '_select_custom_logger_for_cold_storage', return_value="s3_v2"), \
patch('litellm.logging_callback_manager.get_active_custom_logger_for_callback_name', return_value=mock_logger):
result = await handler.get_proxy_server_request_from_cold_storage_with_object_key(
object_key="test/path/to/object.json"
)
assert result == {"test": "data"}
mock_logger.get_proxy_server_request_from_cold_storage_with_object_key.assert_called_once_with(
object_key="test/path/to/object.json"
)
@pytest.mark.asyncio
@patch('asyncio.create_task') # Mock asyncio.create_task to avoid event loop issues
async def test_s3_logger_supports_object_key_retrieval(self, mock_create_task):
"""
Test: S3Logger supports retrieval using provided object key.
This test verifies that the S3Logger can retrieve objects using
the object key directly without generating it from request_id and start_time.
"""
# Create S3Logger instance
s3_logger = S3Logger(s3_bucket_name="test-bucket")
# Mock the _download_object_from_s3 method
with patch.object(s3_logger, '_download_object_from_s3', return_value={"test": "data"}) as mock_download:
result = await s3_logger.get_proxy_server_request_from_cold_storage_with_object_key(
object_key="test/path/to/object.json"
)
assert result == {"test": "data"}
mock_download.assert_called_once_with("test/path/to/object.json")
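
Read together, these tests describe a simple round trip: the logger generates an object key when cold storage is enabled, the key is written into SpendLogsMetadata, and the session handler later reads it back to fetch the full proxy request. A minimal sketch of that flow follows; only the "cold_storage_object_key" field name comes from the tests, while the key layout and helper names are assumptions made for illustration.

import json
from datetime import datetime


def build_object_key(request_id: str, start_time: datetime) -> str:
    """Hypothetical key layout: one JSON object per request, grouped by day."""
    return f"requests/{start_time.strftime('%Y-%m-%d')}/{request_id}.json"


def attach_key_to_spend_log_metadata(metadata: dict, object_key: str) -> str:
    """Store the key alongside the rest of the spend-log metadata."""
    metadata["cold_storage_object_key"] = object_key
    return json.dumps(metadata)


def read_key_from_spend_log(spend_log: dict):
    """Metadata may arrive as a JSON string or an already-parsed dict."""
    metadata = spend_log.get("metadata") or {}
    if isinstance(metadata, str):
        metadata = json.loads(metadata)
    return metadata.get("cold_storage_object_key")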

@@ -0,0 +1,203 @@
import base64
import json
import os
import sys
import pytest
from fastapi.testclient import TestClient
sys.path.insert(
0, os.path.abspath("../../..")
) # Adds the parent directory to the system path
import litellm
from litellm.llms.base_llm.responses.transformation import BaseResponsesAPIConfig
from litellm.llms.openai.responses.transformation import OpenAIResponsesAPIConfig
from litellm.responses.utils import ResponseAPILoggingUtils, ResponsesAPIRequestUtils
from litellm.types.llms.openai import ResponsesAPIOptionalRequestParams
from litellm.types.utils import Usage


class TestResponsesAPIRequestUtils:
def test_get_optional_params_responses_api(self):
"""Test that optional parameters are correctly processed for responses API"""
# Setup
model = "gpt-4o"
config = OpenAIResponsesAPIConfig()
optional_params = ResponsesAPIOptionalRequestParams(
{
"temperature": 0.7,
"max_output_tokens": 100,
"prompt": {"id": "pmpt_123"},
}
)
# Execute
result = ResponsesAPIRequestUtils.get_optional_params_responses_api(
model=model,
responses_api_provider_config=config,
response_api_optional_params=optional_params,
)
# Assert
assert result == optional_params
assert "temperature" in result
assert result["temperature"] == 0.7
assert "max_output_tokens" in result
assert result["max_output_tokens"] == 100
assert "prompt" in result
assert result["prompt"] == {"id": "pmpt_123"}
def test_get_optional_params_responses_api_unsupported_param(self):
"""Test that unsupported parameters raise an error"""
# Setup
model = "gpt-4o"
config = OpenAIResponsesAPIConfig()
optional_params = ResponsesAPIOptionalRequestParams(
{"temperature": 0.7, "unsupported_param": "value"}
)
# Execute and Assert
with pytest.raises(litellm.UnsupportedParamsError) as excinfo:
ResponsesAPIRequestUtils.get_optional_params_responses_api(
model=model,
responses_api_provider_config=config,
response_api_optional_params=optional_params,
)
assert "unsupported_param" in str(excinfo.value)
assert model in str(excinfo.value)
def test_get_requested_response_api_optional_param(self):
"""Test filtering parameters to only include those in ResponsesAPIOptionalRequestParams"""
# Setup
params = {
"temperature": 0.7,
"max_output_tokens": 100,
"prompt": {"id": "pmpt_456"},
"invalid_param": "value",
"model": "gpt-4o", # This is not in ResponsesAPIOptionalRequestParams
}
# Execute
result = ResponsesAPIRequestUtils.get_requested_response_api_optional_param(
params
)
# Assert
assert "temperature" in result
assert "max_output_tokens" in result
assert "invalid_param" not in result
assert "model" not in result
assert result["temperature"] == 0.7
assert result["max_output_tokens"] == 100
assert result["prompt"] == {"id": "pmpt_456"}
def test_decode_previous_response_id_to_original_previous_response_id(self):
"""Test decoding a LiteLLM encoded previous_response_id to the original previous_response_id"""
# Setup
test_provider = "openai"
test_model_id = "gpt-4o"
original_response_id = "resp_abc123"
# Use the helper method to build an encoded response ID
encoded_id = ResponsesAPIRequestUtils._build_responses_api_response_id(
custom_llm_provider=test_provider,
model_id=test_model_id,
response_id=original_response_id,
)
# Execute
result = ResponsesAPIRequestUtils.decode_previous_response_id_to_original_previous_response_id(
encoded_id
)
# Assert
assert result == original_response_id
# Test with a non-encoded ID
plain_id = "resp_xyz789"
result_plain = ResponsesAPIRequestUtils.decode_previous_response_id_to_original_previous_response_id(
plain_id
)
assert result_plain == plain_id
def test_update_responses_api_response_id_with_model_id_handles_dict(self):
"""Ensure _update_responses_api_response_id_with_model_id works with dict input"""
responses_api_response = {"id": "resp_abc123"}
litellm_metadata = {"model_info": {"id": "gpt-4o"}}
updated = ResponsesAPIRequestUtils._update_responses_api_response_id_with_model_id(
responses_api_response=responses_api_response,
custom_llm_provider="openai",
litellm_metadata=litellm_metadata,
)
assert updated["id"] != "resp_abc123"
decoded = ResponsesAPIRequestUtils._decode_responses_api_response_id(updated["id"])
assert decoded.get("response_id") == "resp_abc123"
assert decoded.get("model_id") == "gpt-4o"
assert decoded.get("custom_llm_provider") == "openai"


class TestResponseAPILoggingUtils:
def test_is_response_api_usage_true(self):
"""Test identification of Response API usage format"""
# Setup
usage = {"input_tokens": 10, "output_tokens": 20}
# Execute
result = ResponseAPILoggingUtils._is_response_api_usage(usage)
# Assert
assert result is True
def test_is_response_api_usage_false(self):
"""Test identification of non-Response API usage format"""
# Setup
usage = {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30}
# Execute
result = ResponseAPILoggingUtils._is_response_api_usage(usage)
# Assert
assert result is False
def test_transform_response_api_usage_to_chat_usage(self):
"""Test transformation from Response API usage to Chat usage format"""
# Setup
usage = {
"input_tokens": 10,
"output_tokens": 20,
"total_tokens": 30,
"output_tokens_details": {"reasoning_tokens": 5},
}
# Execute
result = ResponseAPILoggingUtils._transform_response_api_usage_to_chat_usage(
usage
)
# Assert
assert isinstance(result, Usage)
assert result.prompt_tokens == 10
assert result.completion_tokens == 20
assert result.total_tokens == 30
def test_transform_response_api_usage_with_none_values(self):
"""Test transformation handles None values properly"""
# Setup
usage = {
"input_tokens": 0, # Changed from None to 0
"output_tokens": 20,
"total_tokens": 20,
"output_tokens_details": {"reasoning_tokens": 5},
}
# Execute
result = ResponseAPILoggingUtils._transform_response_api_usage_to_chat_usage(
usage
)
# Assert
assert result.prompt_tokens == 0
assert result.completion_tokens == 20
assert result.total_tokens == 20
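
For reference, the mapping the last two tests exercise is a straight field rename: Responses API usage reports input_tokens/output_tokens, while the chat-completion Usage object expects prompt_tokens/completion_tokens. A minimal sketch follows; the field names come from the tests, and the zero fallback for missing keys is an assumption rather than documented behaviour.

from litellm.types.utils import Usage


def sketch_usage_mapping(response_api_usage: dict) -> Usage:
    """Map Responses API token counts onto the chat-completion Usage shape."""
    prompt_tokens = response_api_usage.get("input_tokens") or 0
    completion_tokens = response_api_usage.get("output_tokens") or 0
    total_tokens = response_api_usage.get("total_tokens") or (
        prompt_tokens + completion_tokens
    )
    return Usage(
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
        total_tokens=total_tokens,
    )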