Homelab/Development/litellm/tests/test_litellm/integrations/test_mlflow.py

import asyncio
import os
import sys
from unittest.mock import MagicMock, patch

# Adds the grandparent directory to sys.path to allow importing project modules
sys.path.insert(0, os.path.abspath("../.."))

import pytest

import litellm


@pytest.mark.asyncio
async def test_mlflow_request_tags_functionality():
    """Test that request_tags are properly extracted and transformed into tags for MLflow traces."""
    # Mock MLflow client and dependencies
    mock_client = MagicMock()
    mock_span = MagicMock()
    mock_span.parent_id = None  # Simulate root trace
    mock_span.request_id = "test_trace_id"
    mock_client.start_trace.return_value = mock_span

    # Mock all MLflow-related imports to avoid requiring MLflow as a dependency
    mock_mlflow_tracking = MagicMock()
    mock_mlflow_tracking.MlflowClient = MagicMock(return_value=mock_client)
    mock_mlflow_entities = MagicMock()
    mock_mlflow_entities.SpanStatusCode.OK = "OK"
    mock_mlflow_entities.SpanStatusCode.ERROR = "ERROR"
    mock_mlflow_entities.SpanType.LLM = "LLM"
    mock_mlflow = MagicMock()
    mock_mlflow.get_current_active_span.return_value = None

    with patch.dict('sys.modules', {
        'mlflow': mock_mlflow,
        'mlflow.tracking': mock_mlflow_tracking,
        'mlflow.entities': mock_mlflow_entities,
        'mlflow.tracing.utils': MagicMock(),
    }):
        # Now we can safely import MlflowLogger
        from litellm.integrations.mlflow import MlflowLogger

        # Create MlflowLogger instance
        mlflow_logger = MlflowLogger()
        litellm.callbacks = [mlflow_logger]

        # Test completion with request_tags
        await litellm.acompletion(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "test message"}],
            mock_response="test response",
            metadata={
                "tags": ["tag1", "tag2", "production"]
            }
        )

        # Allow time for async processing
        await asyncio.sleep(1)

        # Verify start_trace was called with tags parameter
        assert mock_client.start_trace.called, "start_trace should have been called"

        # Get the call arguments
        call_args = mock_client.start_trace.call_args
        assert call_args is not None, "start_trace call args should not be None"

        # Check that tags parameter was included and properly transformed
        tags_param = call_args.kwargs.get('tags', {})
        expected_tags = {"tag1": "", "tag2": "", "production": ""}
        assert tags_param == expected_tags, f"Expected tags {expected_tags}, got {tags_param}"
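

# The assertion above encodes the expected transformation: the request's tag list
# is assumed to be flattened into an MLflow tag dict keyed by tag name with
# empty-string values before being passed to start_trace. Roughly (an
# illustrative sketch, not the integration's actual code):
#
#     tags = {tag: "" for tag in metadata.get("tags", [])}
#     client.start_trace(..., tags=tags)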


def test_mlflow_token_usage_attribute_structure():
    """Ensure token usage attributes are formatted with mlflow.chat.tokenUsage."""
    # Stub out the MLflow modules so MlflowLogger can be imported without MLflow installed
    mock_mlflow_tracking = MagicMock()
    mock_mlflow_tracking.MlflowClient = MagicMock()

    with patch.dict(
        "sys.modules",
        {
            "mlflow": MagicMock(),
            "mlflow.tracking": mock_mlflow_tracking,
            "mlflow.tracing.utils": MagicMock(),
        },
    ):
        from litellm.integrations.mlflow import MlflowLogger

        mlflow_logger = MlflowLogger()

        # Extract span attributes from a minimal logging payload and check that the
        # token counts are grouped under the mlflow.chat.tokenUsage key
        attrs = mlflow_logger._extract_attributes(  # type: ignore
            {
                "litellm_call_id": "123",
                "call_type": "completion",
                "model": "gpt-3.5-turbo",
                "standard_logging_object": {
                    "prompt_tokens": 5,
                    "completion_tokens": 7,
                    "total_tokens": 12,
                },
            }
        )

        assert attrs["mlflow.chat.tokenUsage"] == {
            "input_tokens": 5,
            "output_tokens": 7,
            "total_tokens": 12,
        }
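

# test_mlflow_request_tags_functionality assigns to the global litellm.callbacks
# list, which can leak into tests that run later in the same session. A minimal
# sketch of an autouse cleanup fixture (the fixture name is hypothetical; in this
# repo it would more naturally live in a conftest.py):
@pytest.fixture(autouse=True)
def _restore_litellm_callbacks():
    saved = list(litellm.callbacks)
    yield
    litellm.callbacks = saved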