Added LiteLLM to the stack
This commit is contained in:
@@ -0,0 +1,167 @@
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from typing import Optional
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
# Adds the grandparent directory to sys.path to allow importing project modules
|
||||
sys.path.insert(0, os.path.abspath("../.."))
|
||||
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
|
||||
|
||||
import litellm
|
||||
from litellm.integrations.arize.arize import ArizeLogger
|
||||
from litellm.integrations.opentelemetry import OpenTelemetryConfig
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_arize_dynamic_params():
    """Test that the OpenTelemetry logger uses the correct dynamic headers for each Arize request.

    Two mocked completions are issued with different per-request Arize
    credentials; the kwargs handed to ``get_tracer_to_use_for_request`` must
    carry each request's own ``arize_api_key`` / ``arize_space_id``.
    """
    # Create ArizeLogger instance
    arize_logger = ArizeLogger()

    # Capture the kwargs of every get_tracer_to_use_for_request call so the
    # dynamic headers can be asserted on after the requests complete.
    tracer_calls = []

    def mock_get_tracer_to_use_for_request(kwargs):
        tracer_calls.append(kwargs)
        # Return the default tracer so the request still completes normally.
        return arize_logger.tracer

    # Mock the get_tracer_to_use_for_request method
    arize_logger.get_tracer_to_use_for_request = mock_get_tracer_to_use_for_request

    # Set up callbacks
    litellm.callbacks = [arize_logger]

    # First request with team1 credentials
    await litellm.acompletion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "hi test from arize dynamic config"}],
        temperature=0.1,
        mock_response="test_response",
        arize_api_key="team1_key",
        arize_space_id="team1_space_id",
    )

    # Second request with team2 credentials
    await litellm.acompletion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "hi test from arize dynamic config"}],
        temperature=0.1,
        mock_response="test_response",
        arize_api_key="team2_key",
        arize_space_id="team2_space_id",
    )

    # Allow some time for async processing
    await asyncio.sleep(5)

    # Assertions
    print(f"Tracer calls: {len(tracer_calls)}")

    # We should have captured calls for both requests
    assert len(tracer_calls) >= 2, f"Expected at least 2 tracer calls, got {len(tracer_calls)}"

    # Check that we have the expected dynamic params in the kwargs
    team1_found = False
    team2_found = False

    print("args to tracer calls", tracer_calls)

    for call_kwargs in tracer_calls:
        # `or {}` guards against the key being present but explicitly None.
        dynamic_params = call_kwargs.get("standard_callback_dynamic_params") or {}
        if dynamic_params.get("arize_api_key") == "team1_key":
            team1_found = True
            assert dynamic_params.get("arize_space_id") == "team1_space_id"
        elif dynamic_params.get("arize_api_key") == "team2_key":
            team2_found = True
            assert dynamic_params.get("arize_space_id") == "team2_space_id"

    # Verify both teams were found
    assert team1_found, "team1 dynamic params not found"
    assert team2_found, "team2 dynamic params not found"

    print("✅ All assertions passed - OpenTelemetry logger correctly received dynamic params")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_arize_dynamic_headers_in_grpc_requests():
    """Test that dynamic Arize params are passed as headers to the gRPC/HTTP exporter."""
    # Every headers= kwarg the patched exporter constructor receives.
    captured_headers = []

    def fake_http_exporter(*args, **kwargs):
        # Record the headers this exporter instance was constructed with.
        captured_headers.append(kwargs.get('headers', {}))

        # Hand back a stub exporter so span export becomes a no-op.
        stub = MagicMock()
        stub.export = MagicMock(return_value=None)
        return stub

    # Arize exports over HTTP by default, so patch the HTTP exporter class.
    with patch('opentelemetry.exporter.otlp.proto.http.trace_exporter.OTLPSpanExporter', fake_http_exporter):
        # Build an ArizeLogger wired for HTTP export.
        otel_config = OpenTelemetryConfig(
            exporter="otlp_http",
            endpoint="https://otlp.arize.com/v1",
        )
        litellm.callbacks = [ArizeLogger(config=otel_config)]

        # Fire one request per team, each with its own dynamic credentials.
        for team, reply in (("team1", "response1"), ("team2", "response2")):
            await litellm.acompletion(
                model="gpt-3.5-turbo",
                messages=[{"role": "user", "content": f"hi from {team}"}],
                mock_response=reply,
                arize_api_key=f"{team}_api_key",
                arize_space_id=f"{team}_space_id",
            )

        # Allow time for async processing
        await asyncio.sleep(3)

        print(f"Captured exporter headers: {captured_headers}")

        # Should have multiple exporter calls (default + dynamic)
        assert len(captured_headers) >= 2, f"Expected at least 2 exporter calls, got {len(captured_headers)}"

        # Track which team's credentials showed up in exporter headers.
        seen = {"team1": False, "team2": False}
        for hdrs in captured_headers:
            if hdrs.get('api_key') == 'team1_api_key' and hdrs.get('arize-space-id') == 'team1_space_id':
                seen["team1"] = True
                print(f"✅ Found team1 headers: {hdrs}")
            elif hdrs.get('api_key') == 'team2_api_key' and hdrs.get('arize-space-id') == 'team2_space_id':
                seen["team2"] = True
                print(f"✅ Found team2 headers: {hdrs}")

        # Verify both dynamic header sets were used
        assert seen["team1"], "team1 dynamic headers not found in exporter calls"
        assert seen["team2"], "team2 dynamic headers not found in exporter calls"

        print("✅ Test passed - Dynamic Arize params correctly passed to gRPC/HTTP exporter")
|
||||
|
||||
|
@@ -0,0 +1,46 @@
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
from litellm.integrations.arize.arize_phoenix import ArizePhoenixLogger
|
||||
|
||||
|
||||
class TestArizePhoenixConfig(unittest.TestCase):
    """Checks that ArizePhoenixLogger.get_arize_phoenix_config resolves
    endpoint, protocol, and auth headers from Phoenix environment variables."""

    @patch.dict(
        "os.environ",
        {
            "PHOENIX_API_KEY": "test_api_key",
            "PHOENIX_COLLECTOR_HTTP_ENDPOINT": "http://test.endpoint",
        },
    )
    def test_get_arize_phoenix_config_http(self):
        # An HTTP collector endpoint should produce an otlp_http config.
        resolved = ArizePhoenixLogger.get_arize_phoenix_config()

        self.assertEqual(resolved.protocol, "otlp_http")
        self.assertEqual(resolved.endpoint, "http://test.endpoint")
        self.assertEqual(
            resolved.otlp_auth_headers, "Authorization=Bearer%20test_api_key"
        )

    @patch.dict(
        "os.environ",
        {
            "PHOENIX_API_KEY": "test_api_key",
            "PHOENIX_COLLECTOR_ENDPOINT": "grpc://test.endpoint",
        },
    )
    def test_get_arize_phoenix_config_grpc(self):
        # The generic collector endpoint should produce an otlp_grpc config.
        resolved = ArizePhoenixLogger.get_arize_phoenix_config()

        self.assertEqual(resolved.protocol, "otlp_grpc")
        self.assertEqual(resolved.endpoint, "grpc://test.endpoint")
        self.assertEqual(
            resolved.otlp_auth_headers, "Authorization=Bearer%20test_api_key"
        )
|
||||
|
||||
|
||||
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()
|
@@ -0,0 +1,284 @@
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from typing import Optional
|
||||
|
||||
# Adds the grandparent directory to sys.path to allow importing project modules
|
||||
sys.path.insert(0, os.path.abspath("../.."))
|
||||
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
|
||||
import litellm
|
||||
from litellm.integrations._types.open_inference import (
|
||||
MessageAttributes,
|
||||
SpanAttributes,
|
||||
ToolCallAttributes,
|
||||
)
|
||||
from litellm.integrations.arize.arize import ArizeLogger
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm.types.utils import Choices, StandardCallbackDynamicParams
|
||||
|
||||
|
||||
def test_arize_set_attributes():
    """
    Test setting attributes for Arize, including all custom LLM attributes.
    Ensures that the correct span attributes are being added during a request.
    """
    from unittest.mock import MagicMock

    from litellm.types.utils import ModelResponse

    span = MagicMock()  # Mocked tracing span to test attribute setting

    # Construct kwargs to simulate a real LLM request scenario
    kwargs = {
        "model": "gpt-4o",
        "messages": [{"role": "user", "content": "Basic Request Content"}],
        "standard_logging_object": {
            "model_parameters": {"user": "test_user"},
            # key_2 is deliberately None to exercise null metadata serialization
            "metadata": {"key_1": "value_1", "key_2": None},
            "call_type": "completion",
        },
        "optional_params": {
            # Values are strings on purpose — asserted verbatim below
            "max_tokens": "100",
            "temperature": "1",
            "top_p": "5",
            "stream": False,
            "user": "test_user",
            "tools": [
                {
                    "function": {
                        "name": "get_weather",
                        "description": "Fetches weather details.",
                        "parameters": {
                            "type": "object",
                            "properties": {
                                "location": {
                                    "type": "string",
                                    "description": "City name",
                                }
                            },
                            "required": ["location"],
                        },
                    }
                }
            ],
            "functions": [{"name": "get_weather"}, {"name": "get_stock_price"}],
        },
        "litellm_params": {"custom_llm_provider": "openai"},
    }

    # Simulated LLM response object
    response_obj = ModelResponse(
        usage={"total_tokens": 100, "completion_tokens": 60, "prompt_tokens": 40},
        choices=[
            Choices(message={"role": "assistant", "content": "Basic Response Content"})
        ],
        model="gpt-4o",
        id="chatcmpl-ID",
    )

    # Apply attribute setting via ArizeLogger
    ArizeLogger.set_arize_attributes(span, kwargs, response_obj)

    # Validate that the expected number of attributes were set.
    # NOTE(review): this exact count is brittle — it must be bumped whenever
    # set_arize_attributes starts emitting more/fewer attributes.
    assert span.set_attribute.call_count == 28

    # Metadata attached to the span (serialized as JSON)
    span.set_attribute.assert_any_call(
        SpanAttributes.METADATA, json.dumps({"key_1": "value_1", "key_2": None})
    )

    # Basic LLM information
    span.set_attribute.assert_any_call(SpanAttributes.LLM_MODEL_NAME, "gpt-4o")
    span.set_attribute.assert_any_call("llm.request.type", "completion")
    span.set_attribute.assert_any_call(SpanAttributes.LLM_PROVIDER, "openai")

    # LLM generation parameters
    span.set_attribute.assert_any_call("llm.request.max_tokens", "100")
    span.set_attribute.assert_any_call("llm.request.temperature", "1")
    span.set_attribute.assert_any_call("llm.request.top_p", "5")

    # Streaming and user info
    span.set_attribute.assert_any_call("llm.is_streaming", "False")
    span.set_attribute.assert_any_call("llm.user", "test_user")

    # Response metadata
    span.set_attribute.assert_any_call("llm.response.id", "chatcmpl-ID")
    span.set_attribute.assert_any_call("llm.response.model", "gpt-4o")
    span.set_attribute.assert_any_call(SpanAttributes.OPENINFERENCE_SPAN_KIND, "LLM")

    # Request message content and metadata
    span.set_attribute.assert_any_call(
        SpanAttributes.INPUT_VALUE, "Basic Request Content"
    )
    span.set_attribute.assert_any_call(
        f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.{MessageAttributes.MESSAGE_ROLE}",
        "user",
    )
    span.set_attribute.assert_any_call(
        f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.{MessageAttributes.MESSAGE_CONTENT}",
        "Basic Request Content",
    )

    # Tool call definitions and function names
    span.set_attribute.assert_any_call(
        f"{SpanAttributes.LLM_TOOLS}.0.{SpanAttributes.TOOL_NAME}", "get_weather"
    )
    span.set_attribute.assert_any_call(
        f"{SpanAttributes.LLM_TOOLS}.0.{SpanAttributes.TOOL_DESCRIPTION}",
        "Fetches weather details.",
    )
    span.set_attribute.assert_any_call(
        f"{SpanAttributes.LLM_TOOLS}.0.{SpanAttributes.TOOL_PARAMETERS}",
        json.dumps(
            {
                "type": "object",
                "properties": {
                    "location": {"type": "string", "description": "City name"}
                },
                "required": ["location"],
            }
        ),
    )

    # Tool calls captured from optional_params
    span.set_attribute.assert_any_call(
        f"{MessageAttributes.MESSAGE_TOOL_CALLS}.0.{ToolCallAttributes.TOOL_CALL_FUNCTION_NAME}",
        "get_weather",
    )
    span.set_attribute.assert_any_call(
        f"{MessageAttributes.MESSAGE_TOOL_CALLS}.1.{ToolCallAttributes.TOOL_CALL_FUNCTION_NAME}",
        "get_stock_price",
    )

    # Invocation parameters
    span.set_attribute.assert_any_call(
        SpanAttributes.LLM_INVOCATION_PARAMETERS, '{"user": "test_user"}'
    )

    # User ID
    span.set_attribute.assert_any_call(SpanAttributes.USER_ID, "test_user")

    # Output message content
    span.set_attribute.assert_any_call(
        SpanAttributes.OUTPUT_VALUE, "Basic Response Content"
    )
    span.set_attribute.assert_any_call(
        f"{SpanAttributes.LLM_OUTPUT_MESSAGES}.0.{MessageAttributes.MESSAGE_ROLE}",
        "assistant",
    )
    span.set_attribute.assert_any_call(
        f"{SpanAttributes.LLM_OUTPUT_MESSAGES}.0.{MessageAttributes.MESSAGE_CONTENT}",
        "Basic Response Content",
    )

    # Token counts
    span.set_attribute.assert_any_call(SpanAttributes.LLM_TOKEN_COUNT_TOTAL, 100)
    span.set_attribute.assert_any_call(SpanAttributes.LLM_TOKEN_COUNT_COMPLETION, 60)
    span.set_attribute.assert_any_call(SpanAttributes.LLM_TOKEN_COUNT_PROMPT, 40)
|
||||
|
||||
|
||||
class TestArizeLogger(CustomLogger):
    """
    Callback stub that records the ``standard_callback_dynamic_params`` it
    receives, so tests can verify dynamic config keys reach callbacks.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Populated by async_log_success_event once a request is logged.
        self.standard_callback_dynamic_params: Optional[StandardCallbackDynamicParams] = None

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        # Dump the logged kwargs for debugging, then stash the dynamic params.
        print("logged kwargs", json.dumps(kwargs, indent=4, default=str))
        self.standard_callback_dynamic_params = kwargs.get("standard_callback_dynamic_params")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_arize_dynamic_params():
    """
    Ensure dynamic Arize keys (API key and space key) supplied per request
    are received inside the callback logger at runtime.
    """
    capture_logger = TestArizeLogger()
    litellm.callbacks = [capture_logger]

    # Mocked async completion carrying the dynamic Arize keys.
    await litellm.acompletion(
        model="gpt-4o",
        messages=[{"role": "user", "content": "Basic Request Content"}],
        mock_response="test",
        arize_api_key="test_api_key_dynamic",
        arize_space_key="test_space_key_dynamic",
    )

    # Give the async success callback time to fire.
    await asyncio.sleep(2)

    # The callback must have captured both dynamic keys.
    dynamic_params = capture_logger.standard_callback_dynamic_params
    assert dynamic_params is not None
    assert dynamic_params.get("arize_api_key") == "test_api_key_dynamic"
    assert dynamic_params.get("arize_space_key") == "test_space_key_dynamic"
|
||||
|
||||
|
||||
def test_construct_dynamic_arize_headers():
    """
    Test the construct_dynamic_arize_headers method with various input scenarios.
    Ensures that dynamic Arize headers are properly constructed from callback parameters.
    """
    from litellm.types.utils import StandardCallbackDynamicParams

    # Test with all parameters present
    dynamic_params_full = StandardCallbackDynamicParams(
        arize_api_key="test_api_key",
        arize_space_id="test_space_id"
    )
    arize_logger = ArizeLogger()

    headers = arize_logger.construct_dynamic_otel_headers(dynamic_params_full)
    expected_headers = {
        "api_key": "test_api_key",
        "arize-space-id": "test_space_id"
    }
    assert headers == expected_headers

    # Test with only space_id
    dynamic_params_space_id_only = StandardCallbackDynamicParams(
        arize_space_id="test_space_id"
    )

    headers = arize_logger.construct_dynamic_otel_headers(dynamic_params_space_id_only)
    expected_headers = {
        "arize-space-id": "test_space_id"
    }
    assert headers == expected_headers

    # Test with empty parameters dict
    dynamic_params_empty = StandardCallbackDynamicParams()

    headers = arize_logger.construct_dynamic_otel_headers(dynamic_params_empty)
    assert headers == {}

    # test with space key and api key
    # NOTE(review): arize_space_key apparently maps onto the arize-space-id
    # header — confirm against construct_dynamic_otel_headers.
    dynamic_params_space_key_and_api_key = StandardCallbackDynamicParams(
        arize_space_key="test_space_key",
        arize_api_key="test_api_key"
    )
    headers = arize_logger.construct_dynamic_otel_headers(dynamic_params_space_key_and_api_key)
    expected_headers = {
        "arize-space-id": "test_space_key",
        "api_key": "test_api_key"
    }
    # NOTE(review): the final `assert headers == expected_headers` appears
    # truncated in this view of the diff — confirm it exists in the file.
|
Reference in New Issue
Block a user