# Homelab/Development/litellm/tests/mcp_tests/test_mcp_logging.py
#
# Tests for MCP tool-call cost tracking (server default and per-tool costs)
# and for the post-MCP-tool-call logging hook.

import os
import sys
import pytest
import asyncio
import json
from typing import Optional
from unittest.mock import AsyncMock, patch
sys.path.insert(
0, os.path.abspath("../../..")
) # Adds the parent directory to the system path
import litellm
from litellm.types.utils import StandardLoggingPayload
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy._experimental.mcp_server.server import (
mcp_server_tool_call,
)
from litellm.proxy._experimental.mcp_server.mcp_server_manager import (
MCPServerManager,
)
from litellm.types.mcp import MCPPostCallResponseObject
from litellm.types.utils import HiddenParams
from mcp.types import Tool as MCPTool, CallToolResult, TextContent
class TestMCPLogger(CustomLogger):
    """Test logger that records the last ``standard_logging_object`` it receives.

    The ``Test`` name prefix would otherwise make pytest try to collect this
    helper as a test class (and warn because it defines ``__init__``);
    ``__test__ = False`` opts it out of collection.
    """

    __test__ = False

    def __init__(self):
        # Most recent standard logging payload; None until a success event fires.
        self.standard_logging_payload = None
        super().__init__()

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        """Store ``kwargs["standard_logging_object"]`` (or None) for later assertions."""
        print("success event")
        self.standard_logging_payload = kwargs.get("standard_logging_object", None)
        print(f"Captured standard_logging_payload: {self.standard_logging_payload}")
@pytest.mark.asyncio
async def test_mcp_cost_tracking():
    """A server-level ``default_cost_per_query`` is logged as the call's ``response_cost``.

    The MCP client is replaced by an ``AsyncMock`` returning a canned result.
    The test checks the returned content and that the captured standard
    logging payload carries the configured cost of 1.2.
    """
    litellm.logging_callback_manager._reset_all_callbacks()
    # Restored in ``finally`` so this test's logger does not leak into later tests.
    original_callbacks = litellm.callbacks

    # Canned result the mocked client returns for any tool call.
    mock_result = CallToolResult(
        content=[TextContent(type="text", text="Test response")],
        isError=False,
    )

    # Mocked MCPClient: async context manager exposing a single "add_tools" tool.
    mock_client = AsyncMock()
    mock_client.call_tool = AsyncMock(return_value=mock_result)
    mock_client.list_tools = AsyncMock(
        return_value=[
            MCPTool(
                name="add_tools",
                description="Test tool",
                inputSchema={"type": "object", "properties": {"test": {"type": "string"}}},
            )
        ]
    )
    mock_client.__aenter__ = AsyncMock(return_value=mock_client)
    mock_client.__aexit__ = AsyncMock(return_value=None)
    mock_client.disconnect = AsyncMock(return_value=None)

    def mock_client_constructor(*args, **kwargs):
        # Every MCPClient(...) construction yields the same mock.
        return mock_client

    local_mcp_server_manager = MCPServerManager()
    try:
        with patch(
            "litellm.proxy._experimental.mcp_server.mcp_server_manager.MCPClient",
            mock_client_constructor,
        ):
            local_mcp_server_manager.load_servers_from_config(
                mcp_servers_config={
                    "zapier_gmail_server": {
                        "url": os.getenv("ZAPIER_MCP_HTTPS_SERVER_URL"),
                        "mcp_info": {
                            "mcp_server_cost_info": {
                                "default_cost_per_query": 1.2,
                            }
                        },
                    }
                }
            )

            test_logger = TestMCPLogger()
            litellm.callbacks = [test_logger]

            await local_mcp_server_manager._initialize_tool_name_to_mcp_server_name_mapping()

            # Point both modules' global manager at the local, mocked manager.
            with patch(
                "litellm.proxy._experimental.mcp_server.mcp_server_manager.global_mcp_server_manager",
                local_mcp_server_manager,
            ), patch(
                "litellm.proxy._experimental.mcp_server.server.global_mcp_server_manager",
                local_mcp_server_manager,
            ):
                print(
                    "tool_name_to_mcp_server_name_mapping",
                    local_mcp_server_manager.tool_name_to_mcp_server_name_mapping,
                )
                # Register the mapping explicitly in case initialization against the mock missed it.
                local_mcp_server_manager.tool_name_to_mcp_server_name_mapping["add_tools"] = "zapier_gmail_server"
                local_mcp_server_manager.tool_name_to_mcp_server_name_mapping["zapier_gmail_server-add_tools"] = "zapier_gmail_server"

                # Tool names are prefixed with the server name using a "-" separator.
                response = await mcp_server_tool_call(
                    name="zapier_gmail_server-add_tools",
                    arguments={"test": "test"},
                )

                # Success logging runs asynchronously: poll for up to ~2s
                # (the previous fixed sleep) but stop as soon as it arrives.
                for _ in range(40):
                    if test_logger.standard_logging_payload is not None:
                        break
                    await asyncio.sleep(0.05)

                logged_standard_logging_payload = test_logger.standard_logging_payload
                print("logged_standard_logging_payload", json.dumps(logged_standard_logging_payload, indent=4))

                # The mocked tool result is passed through unchanged.
                assert response is not None
                response_list = list(response)
                assert len(response_list) == 1
                assert isinstance(response_list[0], TextContent)
                assert response_list[0].text == "Test response"

                mock_client.__aenter__.assert_called()
                mock_client.call_tool.assert_called_once()

                # Critical: the call is costed at default_cost_per_query (1.2).
                assert logged_standard_logging_payload is not None, "Standard logging payload should not be None"
                assert logged_standard_logging_payload["response_cost"] == pytest.approx(1.2)
    finally:
        litellm.callbacks = original_callbacks
@pytest.mark.asyncio
async def test_mcp_cost_tracking_per_tool():
    """Per-tool costs from ``tool_name_to_cost_per_query`` override the server default.

    Two tools have explicit costs (5.0 and 0.1) on top of a
    ``default_cost_per_query`` of 0.5. Each call must log its own tool's cost.
    """
    litellm.logging_callback_manager._reset_all_callbacks()
    # Restored in ``finally`` so this test's logger does not leak into later tests.
    original_callbacks = litellm.callbacks

    # Canned result the mocked client returns for any tool call.
    mock_result = CallToolResult(
        content=[TextContent(type="text", text="Test response")],
        isError=False,
    )

    # Mocked MCPClient exposing one expensive and one cheap tool.
    mock_client = AsyncMock()
    mock_client.call_tool = AsyncMock(return_value=mock_result)
    mock_client.list_tools = AsyncMock(
        return_value=[
            MCPTool(
                name="expensive_tool",
                description="Expensive tool",
                inputSchema={"type": "object", "properties": {"data": {"type": "string"}}},
            ),
            MCPTool(
                name="cheap_tool",
                description="Cheap tool",
                inputSchema={"type": "object", "properties": {"data": {"type": "string"}}},
            ),
        ]
    )
    mock_client.__aenter__ = AsyncMock(return_value=mock_client)
    mock_client.__aexit__ = AsyncMock(return_value=None)
    mock_client.disconnect = AsyncMock(return_value=None)

    def mock_client_constructor(*args, **kwargs):
        # Every MCPClient(...) construction yields the same mock.
        return mock_client

    local_mcp_server_manager = MCPServerManager()
    try:
        with patch(
            "litellm.proxy._experimental.mcp_server.mcp_server_manager.MCPClient",
            mock_client_constructor,
        ):
            local_mcp_server_manager.load_servers_from_config(
                mcp_servers_config={
                    "test_server": {
                        "url": os.getenv("ZAPIER_MCP_HTTPS_SERVER_URL"),
                        "mcp_info": {
                            "mcp_server_cost_info": {
                                "default_cost_per_query": 0.5,  # fallback cost
                                "tool_name_to_cost_per_query": {
                                    "expensive_tool": 5.0,
                                    "cheap_tool": 0.1,
                                },
                            }
                        },
                    }
                }
            )

            test_logger = TestMCPLogger()
            litellm.callbacks = [test_logger]

            await local_mcp_server_manager._initialize_tool_name_to_mcp_server_name_mapping()

            # Register the mapping explicitly in case initialization against the mock missed it.
            local_mcp_server_manager.tool_name_to_mcp_server_name_mapping["expensive_tool"] = "test_server"
            local_mcp_server_manager.tool_name_to_mcp_server_name_mapping["test_server-expensive_tool"] = "test_server"
            local_mcp_server_manager.tool_name_to_mcp_server_name_mapping["cheap_tool"] = "test_server"
            local_mcp_server_manager.tool_name_to_mcp_server_name_mapping["test_server-cheap_tool"] = "test_server"

            # Point both modules' global manager at the local, mocked manager.
            with patch(
                "litellm.proxy._experimental.mcp_server.mcp_server_manager.global_mcp_server_manager",
                local_mcp_server_manager,
            ), patch(
                "litellm.proxy._experimental.mcp_server.server.global_mcp_server_manager",
                local_mcp_server_manager,
            ):
                print(
                    "tool_name_to_mcp_server_name_mapping",
                    local_mcp_server_manager.tool_name_to_mcp_server_name_mapping,
                )

                # Call 1: expensive_tool should be costed at 5.0.
                response1 = await mcp_server_tool_call(
                    name="test_server-expensive_tool",
                    arguments={"data": "test_expensive"},
                )
                # Poll for up to ~2s instead of always sleeping 2s.
                for _ in range(40):
                    if test_logger.standard_logging_payload is not None:
                        break
                    await asyncio.sleep(0.05)

                logged_standard_logging_payload_1 = test_logger.standard_logging_payload
                print("logged_standard_logging_payload_1", json.dumps(logged_standard_logging_payload_1, indent=4))
                assert logged_standard_logging_payload_1 is not None, "Standard logging payload 1 should not be None"
                assert logged_standard_logging_payload_1["response_cost"] == pytest.approx(5.0)

                # Clear the captured payload so call 2's assertion can't see call 1's.
                test_logger.standard_logging_payload = None

                # Call 2: cheap_tool should be costed at 0.1.
                response2 = await mcp_server_tool_call(
                    name="test_server-cheap_tool",
                    arguments={"data": "test_cheap"},
                )
                for _ in range(40):
                    if test_logger.standard_logging_payload is not None:
                        break
                    await asyncio.sleep(0.05)

                logged_standard_logging_payload_2 = test_logger.standard_logging_payload
                print("logged_standard_logging_payload_2", json.dumps(logged_standard_logging_payload_2, indent=4))
                assert logged_standard_logging_payload_2 is not None, "Standard logging payload 2 should not be None"
                assert logged_standard_logging_payload_2["response_cost"] == pytest.approx(0.1)

                # Both calls pass the mocked tool result through unchanged.
                for response in (response1, response2):
                    assert response is not None
                    response_list = list(response)
                    assert len(response_list) == 1
                    assert isinstance(response_list[0], TextContent)
                    assert response_list[0].text == "Test response"

                assert mock_client.call_tool.call_count == 2
    finally:
        litellm.callbacks = original_callbacks
class MCPLoggerHook(CustomLogger):
    """Test logger that records the payload and overrides the MCP tool-call cost.

    ``async_post_mcp_tool_call_hook`` writes a fixed ``response_cost`` of 1.42
    onto the response's hidden params. ``test_mcp_tool_call_hook`` uses this to
    check that a hook's cost override shows up in the logged standard payload.
    """

    def __init__(self):
        # Last standard logging payload captured; None until a success event.
        self.standard_logging_payload = None
        super().__init__()

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        """Remember the ``standard_logging_object`` passed with a success event."""
        print("success event")
        captured = kwargs.get("standard_logging_object", None)
        self.standard_logging_payload = captured
        print(f"Captured standard_logging_payload: {captured}")

    async def async_post_mcp_tool_call_hook(
        self,
        kwargs,
        response_obj: MCPPostCallResponseObject,
        start_time,
        end_time,
    ) -> Optional[MCPPostCallResponseObject]:
        """Set the response cost to 1.42 in place and return the same object."""
        print("post mcp tool call response_obj", response_obj)
        hidden = response_obj.hidden_params
        hidden.response_cost = 1.42
        return response_obj
@pytest.mark.asyncio
async def test_mcp_tool_call_hook():
    """A post-MCP-tool-call hook can override the logged ``response_cost``.

    ``MCPLoggerHook`` sets the cost to 1.42 in its post-call hook. The test
    checks that this value, not a configured cost, reaches the standard
    logging payload, and that the mocked tool result is returned unchanged.
    """
    litellm.logging_callback_manager._reset_all_callbacks()
    # Restored in ``finally`` so this test's logger does not leak into later tests.
    original_callbacks = litellm.callbacks

    # Canned result the mocked client returns for any tool call.
    mock_result = CallToolResult(
        content=[TextContent(type="text", text="Test response")],
        isError=False,
    )

    # Mocked MCPClient: async context manager exposing a single "add_tools" tool.
    mock_client = AsyncMock()
    mock_client.call_tool = AsyncMock(return_value=mock_result)
    mock_client.list_tools = AsyncMock(
        return_value=[
            MCPTool(
                name="add_tools",
                description="Test tool",
                inputSchema={"type": "object", "properties": {"test": {"type": "string"}}},
            )
        ]
    )
    mock_client.__aenter__ = AsyncMock(return_value=mock_client)
    mock_client.__aexit__ = AsyncMock(return_value=None)
    mock_client.disconnect = AsyncMock(return_value=None)

    def mock_client_constructor(*args, **kwargs):
        # Every MCPClient(...) construction yields the same mock.
        return mock_client

    local_mcp_server_manager = MCPServerManager()
    try:
        with patch(
            "litellm.proxy._experimental.mcp_server.mcp_server_manager.MCPClient",
            mock_client_constructor,
        ):
            # No mcp_info cost config: any non-None cost must come from the hook.
            local_mcp_server_manager.load_servers_from_config(
                mcp_servers_config={
                    "zapier_gmail_server": {
                        "url": os.getenv("ZAPIER_MCP_HTTPS_SERVER_URL"),
                    }
                }
            )

            test_logger = MCPLoggerHook()
            litellm.callbacks = [test_logger]

            await local_mcp_server_manager._initialize_tool_name_to_mcp_server_name_mapping()

            # Register the mapping explicitly in case initialization against the mock missed it.
            local_mcp_server_manager.tool_name_to_mcp_server_name_mapping["add_tools"] = "zapier_gmail_server"
            local_mcp_server_manager.tool_name_to_mcp_server_name_mapping["zapier_gmail_server-add_tools"] = "zapier_gmail_server"

            # Point both modules' global manager at the local, mocked manager.
            with patch(
                "litellm.proxy._experimental.mcp_server.mcp_server_manager.global_mcp_server_manager",
                local_mcp_server_manager,
            ), patch(
                "litellm.proxy._experimental.mcp_server.server.global_mcp_server_manager",
                local_mcp_server_manager,
            ):
                print(
                    "tool_name_to_mcp_server_name_mapping",
                    local_mcp_server_manager.tool_name_to_mcp_server_name_mapping,
                )

                # Tool names are prefixed with the server name using a "-" separator (not "/").
                response = await mcp_server_tool_call(
                    name="zapier_gmail_server-add_tools",
                    arguments={"test": "test"},
                )

                # Success logging runs asynchronously: poll for up to ~2s
                # (the previous fixed sleep) but stop as soon as it arrives.
                for _ in range(40):
                    if test_logger.standard_logging_payload is not None:
                        break
                    await asyncio.sleep(0.05)

                logged_standard_logging_payload = test_logger.standard_logging_payload
                print("logged_standard_logging_payload", json.dumps(logged_standard_logging_payload, indent=4))

                # The mocked tool result is passed through unchanged.
                assert response is not None
                response_list = list(response)
                assert len(response_list) == 1
                assert isinstance(response_list[0], TextContent)
                assert response_list[0].text == "Test response"
                mock_client.call_tool.assert_called_once()

                # The cost set by MCPLoggerHook's post-call hook is what gets logged.
                assert logged_standard_logging_payload is not None, "Standard logging payload should not be None"
                assert logged_standard_logging_payload["response_cost"] == pytest.approx(1.42)
    finally:
        litellm.callbacks = original_callbacks