Added LiteLLM to the stack
@@ -0,0 +1,286 @@
"""
Test reasoning content preservation in Responses API transformation
"""

from unittest.mock import AsyncMock

from litellm.types.utils import ModelResponseStream, StreamingChoices, Delta
from litellm.responses.litellm_completion_transformation.streaming_iterator import (
    LiteLLMCompletionStreamingIterator,
)
from litellm.responses.litellm_completion_transformation.transformation import (
    LiteLLMCompletionResponsesConfig,
)
from litellm.types.utils import ModelResponse, Choices, Message


class TestReasoningContentStreaming:
    """Test reasoning content preservation during streaming"""
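    # Expected event mapping, based on the assertions below: a delta that carries
    # reasoning_content is surfaced as a "response.reasoning_summary_text.delta"
    # event (the reasoning text is what gets emitted even when regular content is
    # present in the same delta), while a content-only delta becomes
    # "response.output_text.delta".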

    def test_reasoning_content_in_delta(self):
        """Test that reasoning content is preserved in streaming deltas"""
        # Setup
        chunk = ModelResponseStream(
            id="test-id",
            created=1234567890,
            model="test-model",
            object="chat.completion.chunk",
            choices=[
                StreamingChoices(
                    finish_reason=None,
                    index=0,
                    delta=Delta(
                        content="",
                        role="assistant",
                        reasoning_content="Let me think about this problem...",
                    ),
                )
            ],
        )

        mock_stream = AsyncMock()

        iterator = LiteLLMCompletionStreamingIterator(
            litellm_custom_stream_wrapper=mock_stream,
            request_input="Test input",
            responses_api_request={},
        )

        # Execute
        transformed_chunk = (
            iterator._transform_chat_completion_chunk_to_response_api_chunk(chunk)
        )

        # Assert
        assert transformed_chunk.delta == "Let me think about this problem..."
        assert transformed_chunk.type == "response.reasoning_summary_text.delta"

    def test_mixed_content_and_reasoning(self):
        """Test handling of both content and reasoning content"""
        # Setup
        chunk = ModelResponseStream(
            id="test-id",
            created=1234567890,
            model="test-model",
            object="chat.completion.chunk",
            choices=[
                StreamingChoices(
                    finish_reason=None,
                    index=0,
                    delta=Delta(
                        content="Here is the answer",
                        role="assistant",
                        reasoning_content="First, let me analyze...",
                    ),
                )
            ],
        )

        mock_stream = AsyncMock()
        iterator = LiteLLMCompletionStreamingIterator(
            litellm_custom_stream_wrapper=mock_stream,
            request_input="Test input",
            responses_api_request={},
        )

        # Execute
        transformed_chunk = (
            iterator._transform_chat_completion_chunk_to_response_api_chunk(chunk)
        )

        # Assert
        assert transformed_chunk.delta == "First, let me analyze..."
        assert transformed_chunk.type == "response.reasoning_summary_text.delta"

    def test_no_reasoning_content(self):
        """Test handling when no reasoning content is present"""
        # Setup
        chunk = ModelResponseStream(
            id="test-id",
            created=1234567890,
            model="test-model",
            object="chat.completion.chunk",
            choices=[
                StreamingChoices(
                    finish_reason=None,
                    index=0,
                    delta=Delta(
                        content="Regular content only",
                        role="assistant",
                    ),
                )
            ],
        )

        mock_stream = AsyncMock()
        iterator = LiteLLMCompletionStreamingIterator(
            litellm_custom_stream_wrapper=mock_stream,
            request_input="Test input",
            responses_api_request={},
        )

        # Execute
        transformed_chunk = (
            iterator._transform_chat_completion_chunk_to_response_api_chunk(chunk)
        )

        # Assert
        assert transformed_chunk.delta == "Regular content only"
        assert transformed_chunk.type == "response.output_text.delta"


class TestReasoningContentFinalResponse:
    """Test reasoning content preservation in final response transformation"""

    def test_reasoning_content_in_final_response(self):
        """Test that reasoning content is included in final response"""
        # Setup
        response = ModelResponse(
            id="test-id",
            created=1234567890,
            model="test-model",
            object="chat.completion",
            choices=[
                Choices(
                    finish_reason="stop",
                    index=0,
                    message=Message(
                        content="Here is my answer",
                        role="assistant",
                        reasoning_content="Let me think step by step about this problem...",
                    ),
                )
            ],
        )

        # Execute
        responses_api_response = LiteLLMCompletionResponsesConfig.transform_chat_completion_response_to_responses_api_response(
            request_input="Test input",
            responses_api_request={},
            chat_completion_response=response,
        )

        # Assert
        assert hasattr(responses_api_response, "output")
        assert len(responses_api_response.output) > 0

        reasoning_items = [
            item for item in responses_api_response.output if item.type == "reasoning"
        ]
        assert len(reasoning_items) > 0, "No reasoning item found in output"

        reasoning_item = reasoning_items[0]
        assert (
            reasoning_item.content[0].text
            == "Let me think step by step about this problem..."
        )

    def test_no_reasoning_content_in_response(self):
        """Test handling when no reasoning content in response"""
        # Setup
        response = ModelResponse(
            id="test-id",
            created=1234567890,
            model="test-model",
            object="chat.completion",
            choices=[
                Choices(
                    finish_reason="stop",
                    index=0,
                    message=Message(
                        content="Simple answer",
                        role="assistant",
                    ),
                )
            ],
        )

        # Execute
        responses_api_response = LiteLLMCompletionResponsesConfig.transform_chat_completion_response_to_responses_api_response(
            request_input="Test input",
            responses_api_request={},
            chat_completion_response=response,
        )

        # Assert
        reasoning_items = [
            item for item in responses_api_response.output if item.type == "reasoning"
        ]
        assert (
            len(reasoning_items) == 0
        ), "Should have no reasoning items when no reasoning content present"

    def test_multiple_choices_with_reasoning(self):
        """Test handling multiple choices, first with reasoning content"""
        # Setup
        response = ModelResponse(
            id="test-id",
            created=1234567890,
            model="test-model",
            object="chat.completion",
            choices=[
                Choices(
                    finish_reason="stop",
                    index=0,
                    message=Message(
                        content="First answer",
                        role="assistant",
                        reasoning_content="Reasoning for first answer",
                    ),
                ),
                Choices(
                    finish_reason="stop",
                    index=1,
                    message=Message(
                        content="Second answer",
                        role="assistant",
                        reasoning_content="Reasoning for second answer",
                    ),
                ),
            ],
        )

        # Execute
        responses_api_response = LiteLLMCompletionResponsesConfig.transform_chat_completion_response_to_responses_api_response(
            request_input="Test input",
            responses_api_request={},
            chat_completion_response=response,
        )

        # Assert
        reasoning_items = [
            item for item in responses_api_response.output if item.type == "reasoning"
        ]
        assert len(reasoning_items) == 1, "Should have exactly one reasoning item"
        assert reasoning_items[0].content[0].text == "Reasoning for first answer"


def test_streaming_chunk_id_raw():
    """Test that streaming chunk IDs are raw (not encoded) to match OpenAI format"""
    chunk = ModelResponseStream(
        id="chunk-123",
        created=1234567890,
        model="test-model",
        object="chat.completion.chunk",
        choices=[
            StreamingChoices(
                finish_reason=None,
                index=0,
                delta=Delta(content="Hello", role="assistant"),
            )
        ],
    )

    iterator = LiteLLMCompletionStreamingIterator(
        litellm_custom_stream_wrapper=AsyncMock(),
        request_input="Test input",
        responses_api_request={},
        custom_llm_provider="openai",
        litellm_metadata={"model_info": {"id": "gpt-4"}},
    )

    result = iterator._transform_chat_completion_chunk_to_response_api_chunk(chunk)

    # Streaming chunk IDs should be raw (like OpenAI's msg_xxx format)
    assert result.item_id == "chunk-123"  # Should be raw, not encoded
    assert not result.item_id.startswith("resp_")  # Should NOT have resp_ prefix
@@ -0,0 +1,366 @@
import json
import os
import sys
from unittest.mock import AsyncMock, patch

import pytest
from fastapi import HTTPException
from fastapi.testclient import TestClient

sys.path.insert(
    0, os.path.abspath("../../..")
)  # Adds the parent directory to the system path

from litellm.responses.litellm_completion_transformation import session_handler
from litellm.responses.litellm_completion_transformation.session_handler import (
    ResponsesSessionHandler,
)


@pytest.mark.asyncio
async def test_get_chat_completion_message_history_for_previous_response_id():
    """
    Test get_chat_completion_message_history_for_previous_response_id with mock data
    """
    # Mock data based on the provided spend logs (simplified version)
    mock_spend_logs = [
        {
            "request_id": "chatcmpl-935b8dad-fdc2-466e-a8ca-e26e5a8a21bb",
            "call_type": "aresponses",
            "api_key": "88dc28d0f030c55ed4ab77ed8faf098196cb1c05df778539800c9f1243fe6b4b",
            "spend": 0.004803,
            "total_tokens": 329,
            "prompt_tokens": 11,
            "completion_tokens": 318,
            "startTime": "2025-05-30T03:17:06.703+00:00",
            "endTime": "2025-05-30T03:17:11.894+00:00",
            "model": "claude-3-5-sonnet-latest",
            "session_id": "a96757c4-c6dc-4c76-b37e-e7dfa526b701",
            "proxy_server_request": {
                "input": "who is Michael Jordan",
                "model": "anthropic/claude-3-5-sonnet-latest",
            },
            "response": {
                "id": "chatcmpl-935b8dad-fdc2-466e-a8ca-e26e5a8a21bb",
                "model": "claude-3-5-sonnet-20241022",
                "object": "chat.completion",
                "choices": [
                    {
                        "index": 0,
                        "message": {
                            "role": "assistant",
                            "content": "Michael Jordan (born February 17, 1963) is widely considered the greatest basketball player of all time. Here are some key points about him...",
                            "tool_calls": None,
                            "function_call": None,
                        },
                        "finish_reason": "stop",
                    }
                ],
                "created": 1748575031,
                "usage": {
                    "total_tokens": 329,
                    "prompt_tokens": 11,
                    "completion_tokens": 318,
                },
            },
            "status": "success",
        },
        {
            "request_id": "chatcmpl-370760c9-39fa-4db7-b034-d1f8d933c935",
            "call_type": "aresponses",
            "api_key": "88dc28d0f030c55ed4ab77ed8faf098196cb1c05df778539800c9f1243fe6b4b",
            "spend": 0.010437,
            "total_tokens": 967,
            "prompt_tokens": 339,
            "completion_tokens": 628,
            "startTime": "2025-05-30T03:17:28.600+00:00",
            "endTime": "2025-05-30T03:17:39.921+00:00",
            "model": "claude-3-5-sonnet-latest",
            "session_id": "a96757c4-c6dc-4c76-b37e-e7dfa526b701",
            "proxy_server_request": {
                "input": "can you tell me more about him",
                "model": "anthropic/claude-3-5-sonnet-latest",
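                # The previous_response_id below is LiteLLM's encoded response ID:
                # "resp_" + base64 of "litellm:custom_llm_provider:anthropic;model_id:<hash>;response_id:chatcmpl-935b8dad-...",
                # i.e. it points back at the request_id of the first spend log above.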
"previous_response_id": "resp_bGl0ZWxsbTpjdXN0b21fbGxtX3Byb3ZpZGVyOmFudGhyb3BpYzttb2RlbF9pZDplMGYzMDJhMTQxMmU3ODQ3MGViYjI4Y2JlZDAxZmZmNWY4OGMwZDMzMWM2NjdlOWYyYmE0YjQxM2M2ZmJkMjgyO3Jlc3BvbnNlX2lkOmNoYXRjbXBsLTkzNWI4ZGFkLWZkYzItNDY2ZS1hOGNhLWUyNmU1YThhMjFiYg==",
|
||||
},
|
||||
"response": {
|
||||
"id": "chatcmpl-370760c9-39fa-4db7-b034-d1f8d933c935",
|
||||
"model": "claude-3-5-sonnet-20241022",
|
||||
"object": "chat.completion",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "Here's more detailed information about Michael Jordan...",
|
||||
"tool_calls": None,
|
||||
"function_call": None,
|
||||
},
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
],
|
||||
"created": 1748575059,
|
||||
"usage": {
|
||||
"total_tokens": 967,
|
||||
"prompt_tokens": 339,
|
||||
"completion_tokens": 628,
|
||||
},
|
||||
},
|
||||
"status": "success",
|
||||
},
|
||||
]
|
||||
|
||||
# Mock the get_all_spend_logs_for_previous_response_id method
|
||||
with patch.object(
|
||||
ResponsesSessionHandler,
|
||||
"get_all_spend_logs_for_previous_response_id",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_spend_logs:
|
||||
mock_get_spend_logs.return_value = mock_spend_logs
|
||||
|
||||
# Test the function
|
||||
previous_response_id = "chatcmpl-935b8dad-fdc2-466e-a8ca-e26e5a8a21bb"
|
||||
result = await ResponsesSessionHandler.get_chat_completion_message_history_for_previous_response_id(
|
||||
previous_response_id
|
||||
)
|
||||
|
||||
# Verify the mock was called with correct parameters
|
||||
mock_get_spend_logs.assert_called_once_with(previous_response_id)
|
||||
|
||||
# Verify the returned ChatCompletionSession structure
|
||||
assert "messages" in result
|
||||
assert "litellm_session_id" in result
|
||||
|
||||
# Verify session_id is extracted correctly
|
||||
assert result["litellm_session_id"] == "a96757c4-c6dc-4c76-b37e-e7dfa526b701"
|
||||
|
||||
# Verify messages structure
|
||||
messages = result["messages"]
|
||||
assert len(messages) == 4 # 2 user messages + 2 assistant messages
|
||||
|
||||
# Check the message sequence
|
||||
# First user message
|
||||
assert messages[0].get("role") == "user"
|
||||
assert messages[0].get("content") == "who is Michael Jordan"
|
||||
|
||||
# First assistant response
|
||||
assert messages[1].get("role") == "assistant"
|
||||
content_1 = messages[1].get("content", "")
|
||||
if isinstance(content_1, str):
|
||||
assert "Michael Jordan" in content_1
|
||||
assert content_1.startswith("Michael Jordan (born February 17, 1963)")
|
||||
|
||||
# Second user message
|
||||
assert messages[2].get("role") == "user"
|
||||
assert messages[2].get("content") == "can you tell me more about him"
|
||||
|
||||
# Second assistant response
|
||||
assert messages[3].get("role") == "assistant"
|
||||
content_3 = messages[3].get("content", "")
|
||||
if isinstance(content_3, str):
|
||||
assert "Here's more detailed information about Michael Jordan" in content_3
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_chat_completion_message_history_empty_spend_logs():
|
||||
"""
|
||||
Test get_chat_completion_message_history_for_previous_response_id with empty spend logs
|
||||
"""
|
||||
with patch.object(
|
||||
ResponsesSessionHandler,
|
||||
"get_all_spend_logs_for_previous_response_id",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_spend_logs:
|
||||
mock_get_spend_logs.return_value = []
|
||||
|
||||
previous_response_id = "non-existent-id"
|
||||
result = await ResponsesSessionHandler.get_chat_completion_message_history_for_previous_response_id(
|
||||
previous_response_id
|
||||
)
|
||||
|
||||
# Verify empty result structure
|
||||
assert result.get("messages") == []
|
||||
assert result.get("litellm_session_id") is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_e2e_cold_storage_successful_retrieval():
|
||||
"""
|
||||
Test end-to-end cold storage functionality with successful retrieval of full proxy request from cold storage.
|
||||
"""
|
||||
# Mock spend logs with cold storage object key in metadata
|
||||
mock_spend_logs = [
|
||||
{
|
||||
"request_id": "chatcmpl-test-123",
|
||||
"session_id": "session-456",
|
||||
"metadata": '{"cold_storage_object_key": "s3://test-bucket/requests/session_456_req1.json"}',
|
||||
"proxy_server_request": '{"litellm_truncated": true}', # Truncated payload
|
||||
"response": {
|
||||
"id": "chatcmpl-test-123",
|
||||
"object": "chat.completion",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "I am an AI assistant."
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
# Full proxy request data from cold storage
|
||||
full_proxy_request = {
|
||||
"input": "Hello, who are you?",
|
||||
"model": "gpt-4",
|
||||
"messages": [{"role": "user", "content": "Hello, who are you?"}]
|
||||
}
|
||||
|
||||
with patch.object(
|
||||
ResponsesSessionHandler,
|
||||
"get_all_spend_logs_for_previous_response_id",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_spend_logs, \
|
||||
patch.object(session_handler, "COLD_STORAGE_HANDLER") as mock_cold_storage, \
|
||||
patch("litellm.proxy.spend_tracking.cold_storage_handler.ColdStorageHandler._get_configured_cold_storage_custom_logger", return_value="s3"):
|
||||
|
||||
# Setup mocks
|
||||
mock_get_spend_logs.return_value = mock_spend_logs
|
||||
mock_cold_storage.get_proxy_server_request_from_cold_storage_with_object_key = AsyncMock(return_value=full_proxy_request)
|
||||
|
||||
# Call the main function
|
||||
result = await ResponsesSessionHandler.get_chat_completion_message_history_for_previous_response_id(
|
||||
"chatcmpl-test-123"
|
||||
)
|
||||
|
||||
# Verify cold storage was called with correct object key
|
||||
mock_cold_storage.get_proxy_server_request_from_cold_storage_with_object_key.assert_called_once_with(
|
||||
object_key="s3://test-bucket/requests/session_456_req1.json"
|
||||
)
|
||||
|
||||
# Verify result structure
|
||||
assert result.get("litellm_session_id") == "session-456"
|
||||
assert len(result.get("messages", [])) >= 1 # At least the assistant response
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_e2e_cold_storage_fallback_to_truncated_payload():
|
||||
"""
|
||||
Test end-to-end cold storage functionality when object key is missing, falling back to truncated payload.
|
||||
"""
|
||||
# Mock spend logs without cold storage object key
|
||||
mock_spend_logs = [
|
||||
{
|
||||
"request_id": "chatcmpl-test-789",
|
||||
"session_id": "session-999",
|
||||
"metadata": '{"user_api_key": "test-key"}', # No cold storage object key
|
||||
"proxy_server_request": '{"input": "Truncated message", "model": "gpt-4"}', # Regular payload
|
||||
"response": {
|
||||
"id": "chatcmpl-test-789",
|
||||
"object": "chat.completion",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "This is a response."
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
with patch.object(
|
||||
ResponsesSessionHandler,
|
||||
"get_all_spend_logs_for_previous_response_id",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_spend_logs, \
|
||||
patch.object(session_handler, "COLD_STORAGE_HANDLER") as mock_cold_storage:
|
||||
|
||||
# Setup mocks
|
||||
mock_get_spend_logs.return_value = mock_spend_logs
|
||||
|
||||
# Call the main function
|
||||
result = await ResponsesSessionHandler.get_chat_completion_message_history_for_previous_response_id(
|
||||
"chatcmpl-test-789"
|
||||
)
|
||||
|
||||
# Verify cold storage was NOT called since no object key in metadata
|
||||
mock_cold_storage.get_proxy_server_request_from_cold_storage_with_object_key.assert_not_called()
|
||||
|
||||
# Verify result structure
|
||||
assert result.get("litellm_session_id") == "session-999"
|
||||
assert len(result.get("messages", [])) >= 1 # At least the assistant response
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_should_check_cold_storage_for_full_payload():
|
||||
"""
|
||||
Test _should_check_cold_storage_for_full_payload returns True for proxy server requests with truncated content
|
||||
"""
|
||||
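    # Taken together, the cases below imply the decision rule: cold storage is only
    # consulted when a cold-storage logger is configured, and then only if the stored
    # proxy_server_request is missing, empty, or contains truncated content.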

    # Test case 1: Proxy server request with truncated PDF content (should return True)
    proxy_request_with_truncated_pdf = {
        "input": [
            {
                "role": "user",
                "type": "message",
                "content": [
                    {
                        "text": "what was datadogs largest source of operating cash ? quote the section you saw ",
                        "type": "input_text"
                    },
                    {
                        "type": "input_image",
                        "image_url": "data:application/pdf;base64,JVBERi0xLjcKJYGBgYEKCjcgMCBvYmoKPDwKL0ZpbHRlciAvRmxhdGVEZWNvZGUKL0xlbmd0aCA1NjcxCj4+CnN0cmVhbQp4nO1dW4/cthV+31+h5wKVeb8AhoG9Bn0I0DYL9NlInQBFHKSpA+Tnl5qRNNRIn8ij4WpnbdqAsRaX90Oe23cOWyH94U/Dwt+/ttF/neKt59675sfPN/+9Ubp1MvwRjfAtN92fRkgn2+5jI5Xyre9++fdPN//6S/NrqCFax4XqvnVtn/631FLogjfd339+1xx/+P3nm3ffyebn/92ww2Bc46zRrGv/p5vWMOmb+N9Qb/4xtOEazn2oH3rjfV3fDTj+t6s7+zjU5XFdFxo9fPs8/CiaX26cYmc/svDjhlF+Pv7QNdT30/9wbI8dFjK0cfzhUO8wPjaOr/Eq/v/d8827vzfv37/7/v5vD6HKhw93D/c3755UI3jYuOb5p7Dsh53nYQtZqyUXugn71Dx/vnnPmHQfmuf/3HDdKhY2z8jwq8//broSjkrE/aHEtZIxZhQ/VbHHKqoVQhsv7KmKgyUWdcPkoUSHaYgwmmhk5liFt85w45WcDUC0RgntpTp1o44lMhCpl95eNOZ+AI/f3988Pp9tAV/dAu5VKz0Ls+SB0vstgNNZWTUDtw1vqC65oahKctUW5gl3etg1EnXSCWplzHewG7gDoq9jWu28DYuzPB3NucmZzm3UmvFcLI/aMpcxT0w1AvUi2aFEsJZLprxMF0xIQzmXQQArRwA2Fo/YCm7P13LxdIrodIa7QIojL5zekqblloeuwkiGI3o7jk+zMEJ9vqW+NWFDtTDnU7KtCpet1WZGHqyVxjLNxPlcdcudsU648xnNOxmIY95Wf3JduAidF3q+bgtVUC/s6VAgZ/TUE9q8oD+DCwM2qA/UFD/eWqY1RrFQ61QgUYFDBRYV3GOKkSPFaEQxQqqWWRcGzZ3u... (litellm_truncated 1197576 chars)"
                    }
                ]
            }
        ],
        "model": "anthropic/claude-3-7-sonnet-20250219",
        "stream": True,
        "litellm_trace_id": "16b86861-c120-4ecb-865b-4d2238bfd8f0"
    }

    # Test case 2: Regular proxy request without truncation (should return False)
    proxy_request_regular = {
        "input": [
            {
                "role": "user",
                "type": "message",
                "content": "Hello, this is a regular message"
            }
        ],
        "model": "anthropic/claude-3-7-sonnet-20250219",
        "stream": True
    }

    # Test case 3: Empty request (should return True)
    proxy_request_empty = {}

    # Test case 4: None request (should return True)
    proxy_request_none = None

    with patch("litellm.proxy.spend_tracking.cold_storage_handler.ColdStorageHandler._get_configured_cold_storage_custom_logger", return_value="s3"):
        # Test case 1: Should return True for truncated content
        result1 = ResponsesSessionHandler._should_check_cold_storage_for_full_payload(proxy_request_with_truncated_pdf)
        assert result1 == True, "Should return True for proxy request with truncated PDF content"

        # Test case 2: Should return False for regular content
        result2 = ResponsesSessionHandler._should_check_cold_storage_for_full_payload(proxy_request_regular)
        assert result2 == False, "Should return False for regular proxy request without truncation"

        # Test case 3: Should return True for empty request
        result3 = ResponsesSessionHandler._should_check_cold_storage_for_full_payload(proxy_request_empty)
        assert result3 == True, "Should return True for empty proxy request"

        # Test case 4: Should return True for None request
        result4 = ResponsesSessionHandler._should_check_cold_storage_for_full_payload(proxy_request_none)
        assert result4 == True, "Should return True for None proxy request"

    # Test case 5: Should return False when cold storage is not configured
    with patch("litellm.proxy.spend_tracking.cold_storage_handler.ColdStorageHandler._get_configured_cold_storage_custom_logger", return_value=None):
        result5 = ResponsesSessionHandler._should_check_cold_storage_for_full_payload(proxy_request_with_truncated_pdf)
        assert result5 == False, "Should return False when cold storage is not configured, even with truncated content"
@@ -0,0 +1,186 @@
"""
Unit tests for cold storage object key integration.

Tests for the changes to integrate cold storage handling across different components:
1. Add cold_storage_object_key field to StandardLoggingMetadata and SpendLogsMetadata
2. S3Logger generates object key when cold storage is enabled
3. Store object key in SpendLogsMetadata via spend_tracking_utils
4. Session handler uses object key from spend logs metadata
5. S3Logger supports retrieval using provided object key
"""

import json
from datetime import datetime, timezone
from typing import Optional
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from litellm.integrations.s3_v2 import S3Logger
from litellm.proxy._types import SpendLogsMetadata, SpendLogsPayload
from litellm.proxy.spend_tracking.cold_storage_handler import ColdStorageHandler
from litellm.proxy.spend_tracking.spend_tracking_utils import _get_spend_logs_metadata
from litellm.responses.litellm_completion_transformation.session_handler import (
    ResponsesSessionHandler,
)
from litellm.types.utils import StandardLoggingMetadata, StandardLoggingPayload


class TestColdStorageObjectKeyIntegration:
    """Test suite for cold storage object key integration."""
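    # Flow exercised by these tests (as listed in the module docstring): the S3/GCS
    # logger writes the full request to cold storage and records its object key in
    # StandardLoggingMetadata; spend_tracking_utils copies that key into
    # SpendLogsMetadata; the session handler later reads the key back from the spend
    # log and fetches the full payload by key via the cold storage handler.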

    def test_standard_logging_metadata_has_cold_storage_object_key_field(self):
        """
        Test: Add cold_storage_object_key field to StandardLoggingMetadata.

        This test verifies that the StandardLoggingMetadata TypedDict has the
        cold_storage_object_key field for storing S3/GCS object keys.
        """
        from litellm.types.utils import StandardLoggingMetadata

        # Create a StandardLoggingMetadata instance with cold_storage_object_key
        metadata = StandardLoggingMetadata(
            user_api_key_hash="test_hash",
            cold_storage_object_key="test/path/to/object.json"
        )

        # Verify the field can be set and accessed
        assert metadata.get("cold_storage_object_key") == "test/path/to/object.json"

        assert "cold_storage_object_key" in StandardLoggingMetadata.__annotations__

    def test_spend_logs_metadata_has_cold_storage_object_key_field(self):
        """
        Test: Add cold_storage_object_key field to SpendLogsMetadata.

        This test verifies that the SpendLogsMetadata TypedDict has the
        cold_storage_object_key field for storing S3/GCS object keys.
        """
        # Create a SpendLogsMetadata instance with cold_storage_object_key
        metadata = SpendLogsMetadata(
            user_api_key="test_key",
            cold_storage_object_key="test/path/to/object.json"
        )

        # Verify the field can be set and accessed
        assert metadata.get("cold_storage_object_key") == "test/path/to/object.json"

        # Verify it's part of the SpendLogsMetadata annotations
        assert "cold_storage_object_key" in SpendLogsMetadata.__annotations__

    def test_spend_tracking_utils_stores_object_key_in_metadata(self):
        """
        Test: Store object key in SpendLogsMetadata via spend_tracking_utils.

        This test verifies that the _get_spend_logs_metadata function extracts
        the cold_storage_object_key from StandardLoggingPayload and stores it
        in SpendLogsMetadata.
        """
        # Create test data
        metadata = {
            "user_api_key": "test_key",
            "user_api_key_team_id": "test_team"
        }

        # Call the function
        result = _get_spend_logs_metadata(
            metadata=metadata,
            cold_storage_object_key="test/path/to/object.json"
        )

        # Verify the object key is stored in the result
        assert result.get("cold_storage_object_key") == "test/path/to/object.json"

    def test_session_handler_extracts_object_key_from_spend_log(self):
        """
        Test: Session handler extracts object key from spend logs metadata.

        This test verifies that the ResponsesSessionHandler can extract the
        cold_storage_object_key from spend log metadata.
        """
        # Create test spend log
        spend_log = {
            "request_id": "test_request_id",
            "metadata": json.dumps({
                "cold_storage_object_key": "test/path/to/object.json",
                "user_api_key": "test_key"
            })
        }

        # Test the extraction method
        object_key = ResponsesSessionHandler._get_cold_storage_object_key_from_spend_log(spend_log)

        assert object_key == "test/path/to/object.json"

    def test_session_handler_handles_dict_metadata_in_spend_log(self):
        """
        Test: Session handler handles dict metadata in spend log.

        This test verifies that the method works when metadata is already a dict.
        """
        # Create test spend log with dict metadata
        spend_log = {
            "request_id": "test_request_id",
            "metadata": {
                "cold_storage_object_key": "test/path/to/object.json",
                "user_api_key": "test_key"
            }
        }

        # Test the extraction method
        object_key = ResponsesSessionHandler._get_cold_storage_object_key_from_spend_log(spend_log)

        assert object_key == "test/path/to/object.json"

    @pytest.mark.asyncio
    async def test_cold_storage_handler_supports_object_key_retrieval(self):
        """
        Test: ColdStorageHandler supports object key retrieval.

        This test verifies that the ColdStorageHandler has the new method
        for retrieving objects using object keys directly.
        """
        handler = ColdStorageHandler()

        # Mock the custom logger
        mock_logger = AsyncMock()
        mock_logger.get_proxy_server_request_from_cold_storage_with_object_key = AsyncMock(
            return_value={"test": "data"}
        )

        with patch.object(handler, '_select_custom_logger_for_cold_storage', return_value="s3_v2"), \
             patch('litellm.logging_callback_manager.get_active_custom_logger_for_callback_name', return_value=mock_logger):

            result = await handler.get_proxy_server_request_from_cold_storage_with_object_key(
                object_key="test/path/to/object.json"
            )

            assert result == {"test": "data"}
            mock_logger.get_proxy_server_request_from_cold_storage_with_object_key.assert_called_once_with(
                object_key="test/path/to/object.json"
            )

    @pytest.mark.asyncio
    @patch('asyncio.create_task')  # Mock asyncio.create_task to avoid event loop issues
    async def test_s3_logger_supports_object_key_retrieval(self, mock_create_task):
        """
        Test: S3Logger supports retrieval using provided object key.

        This test verifies that the S3Logger can retrieve objects using
        the object key directly without generating it from request_id and start_time.
        """
        # Create S3Logger instance
        s3_logger = S3Logger(s3_bucket_name="test-bucket")

        # Mock the _download_object_from_s3 method
        with patch.object(s3_logger, '_download_object_from_s3', return_value={"test": "data"}) as mock_download:
            result = await s3_logger.get_proxy_server_request_from_cold_storage_with_object_key(
                object_key="test/path/to/object.json"
            )

            assert result == {"test": "data"}
            mock_download.assert_called_once_with("test/path/to/object.json")
@@ -0,0 +1,203 @@
import base64
import json
import os
import sys

import pytest
from fastapi.testclient import TestClient

sys.path.insert(
    0, os.path.abspath("../../..")
)  # Adds the parent directory to the system path

import litellm
from litellm.llms.base_llm.responses.transformation import BaseResponsesAPIConfig
from litellm.llms.openai.responses.transformation import OpenAIResponsesAPIConfig
from litellm.responses.utils import ResponseAPILoggingUtils, ResponsesAPIRequestUtils
from litellm.types.llms.openai import ResponsesAPIOptionalRequestParams
from litellm.types.utils import Usage


class TestResponsesAPIRequestUtils:
    def test_get_optional_params_responses_api(self):
        """Test that optional parameters are correctly processed for responses API"""
        # Setup
        model = "gpt-4o"
        config = OpenAIResponsesAPIConfig()
        optional_params = ResponsesAPIOptionalRequestParams(
            {
                "temperature": 0.7,
                "max_output_tokens": 100,
                "prompt": {"id": "pmpt_123"},
            }
        )

        # Execute
        result = ResponsesAPIRequestUtils.get_optional_params_responses_api(
            model=model,
            responses_api_provider_config=config,
            response_api_optional_params=optional_params,
        )

        # Assert
        assert result == optional_params
        assert "temperature" in result
        assert result["temperature"] == 0.7
        assert "max_output_tokens" in result
        assert result["max_output_tokens"] == 100
        assert "prompt" in result
        assert result["prompt"] == {"id": "pmpt_123"}

    def test_get_optional_params_responses_api_unsupported_param(self):
        """Test that unsupported parameters raise an error"""
        # Setup
        model = "gpt-4o"
        config = OpenAIResponsesAPIConfig()
        optional_params = ResponsesAPIOptionalRequestParams(
            {"temperature": 0.7, "unsupported_param": "value"}
        )

        # Execute and Assert
        with pytest.raises(litellm.UnsupportedParamsError) as excinfo:
            ResponsesAPIRequestUtils.get_optional_params_responses_api(
                model=model,
                responses_api_provider_config=config,
                response_api_optional_params=optional_params,
            )

        assert "unsupported_param" in str(excinfo.value)
        assert model in str(excinfo.value)

    def test_get_requested_response_api_optional_param(self):
        """Test filtering parameters to only include those in ResponsesAPIOptionalRequestParams"""
        # Setup
        params = {
            "temperature": 0.7,
            "max_output_tokens": 100,
            "prompt": {"id": "pmpt_456"},
            "invalid_param": "value",
            "model": "gpt-4o",  # This is not in ResponsesAPIOptionalRequestParams
        }

        # Execute
        result = ResponsesAPIRequestUtils.get_requested_response_api_optional_param(
            params
        )

        # Assert
        assert "temperature" in result
        assert "max_output_tokens" in result
        assert "invalid_param" not in result
        assert "model" not in result
        assert result["temperature"] == 0.7
        assert result["max_output_tokens"] == 100
        assert result["prompt"] == {"id": "pmpt_456"}

    def test_decode_previous_response_id_to_original_previous_response_id(self):
        """Test decoding a LiteLLM encoded previous_response_id to the original previous_response_id"""
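        # Note (inferred from the session-handler fixture in this commit): the encoded form is
        # "resp_" + base64("litellm:custom_llm_provider:<provider>;model_id:<id>;response_id:<original id>").
        # The decoder should recover the original id and pass non-encoded ids through unchanged.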
        # Setup
        test_provider = "openai"
        test_model_id = "gpt-4o"
        original_response_id = "resp_abc123"

        # Use the helper method to build an encoded response ID
        encoded_id = ResponsesAPIRequestUtils._build_responses_api_response_id(
            custom_llm_provider=test_provider,
            model_id=test_model_id,
            response_id=original_response_id,
        )

        # Execute
        result = ResponsesAPIRequestUtils.decode_previous_response_id_to_original_previous_response_id(
            encoded_id
        )

        # Assert
        assert result == original_response_id

        # Test with a non-encoded ID
        plain_id = "resp_xyz789"
        result_plain = ResponsesAPIRequestUtils.decode_previous_response_id_to_original_previous_response_id(
            plain_id
        )
        assert result_plain == plain_id

    def test_update_responses_api_response_id_with_model_id_handles_dict(self):
        """Ensure _update_responses_api_response_id_with_model_id works with dict input"""
        responses_api_response = {"id": "resp_abc123"}
        litellm_metadata = {"model_info": {"id": "gpt-4o"}}
        updated = ResponsesAPIRequestUtils._update_responses_api_response_id_with_model_id(
            responses_api_response=responses_api_response,
            custom_llm_provider="openai",
            litellm_metadata=litellm_metadata,
        )
        assert updated["id"] != "resp_abc123"
        decoded = ResponsesAPIRequestUtils._decode_responses_api_response_id(updated["id"])
        assert decoded.get("response_id") == "resp_abc123"
        assert decoded.get("model_id") == "gpt-4o"
        assert decoded.get("custom_llm_provider") == "openai"


class TestResponseAPILoggingUtils:
    def test_is_response_api_usage_true(self):
        """Test identification of Response API usage format"""
        # Setup
        usage = {"input_tokens": 10, "output_tokens": 20}

        # Execute
        result = ResponseAPILoggingUtils._is_response_api_usage(usage)

        # Assert
        assert result is True

    def test_is_response_api_usage_false(self):
        """Test identification of non-Response API usage format"""
        # Setup
        usage = {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30}

        # Execute
        result = ResponseAPILoggingUtils._is_response_api_usage(usage)

        # Assert
        assert result is False

    def test_transform_response_api_usage_to_chat_usage(self):
        """Test transformation from Response API usage to Chat usage format"""
        # Setup
        usage = {
            "input_tokens": 10,
            "output_tokens": 20,
            "total_tokens": 30,
            "output_tokens_details": {"reasoning_tokens": 5},
        }

        # Execute
        result = ResponseAPILoggingUtils._transform_response_api_usage_to_chat_usage(
            usage
        )

        # Assert
        assert isinstance(result, Usage)
        assert result.prompt_tokens == 10
        assert result.completion_tokens == 20
        assert result.total_tokens == 30

    def test_transform_response_api_usage_with_none_values(self):
        """Test transformation handles None values properly"""
        # Setup
        usage = {
            "input_tokens": 0,  # Changed from None to 0
            "output_tokens": 20,
            "total_tokens": 20,
            "output_tokens_details": {"reasoning_tokens": 5},
        }

        # Execute
        result = ResponseAPILoggingUtils._transform_response_api_usage_to_chat_usage(
            usage
        )

        # Assert
        assert result.prompt_tokens == 0
        assert result.completion_tokens == 20
        assert result.total_tokens == 20