import os import sys import pytest import asyncio import json from typing import Optional from unittest.mock import AsyncMock, patch sys.path.insert( 0, os.path.abspath("../../..") ) # Adds the parent directory to the system path import litellm from litellm.types.utils import StandardLoggingPayload from litellm.integrations.custom_logger import CustomLogger from litellm.proxy._experimental.mcp_server.server import ( mcp_server_tool_call, ) from litellm.proxy._experimental.mcp_server.mcp_server_manager import ( MCPServerManager, ) from litellm.types.mcp import MCPPostCallResponseObject from litellm.types.utils import HiddenParams from mcp.types import Tool as MCPTool, CallToolResult, TextContent class TestMCPLogger(CustomLogger): def __init__(self): self.standard_logging_payload = None super().__init__() async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): print("success event") self.standard_logging_payload = kwargs.get("standard_logging_object", None) print(f"Captured standard_logging_payload: {self.standard_logging_payload}") @pytest.mark.asyncio async def test_mcp_cost_tracking(): # Create a mock tool call result litellm.logging_callback_manager._reset_all_callbacks() mock_result = CallToolResult( content=[TextContent(type="text", text="Test response")], isError=False ) # Create a mock MCPClient mock_client = AsyncMock() mock_client.call_tool = AsyncMock(return_value=mock_result) mock_client.list_tools = AsyncMock(return_value=[ MCPTool( name="add_tools", description="Test tool", inputSchema={"type": "object", "properties": {"test": {"type": "string"}}} ) ]) mock_client.__aenter__ = AsyncMock(return_value=mock_client) mock_client.__aexit__ = AsyncMock(return_value=None) # Mock the MCPClient constructor def mock_client_constructor(*args, **kwargs): return mock_client # Initialize the server manager local_mcp_server_manager = MCPServerManager() with patch('litellm.proxy._experimental.mcp_server.mcp_server_manager.MCPClient', mock_client_constructor): # Load the server config local_mcp_server_manager.load_servers_from_config( mcp_servers_config={ "zapier_gmail_server": { "url": os.getenv("ZAPIER_MCP_HTTPS_SERVER_URL"), "mcp_info": { "mcp_server_cost_info": { "default_cost_per_query": 1.2, } } } } ) # Set up the test logger test_logger = TestMCPLogger() litellm.callbacks = [test_logger] # Initialize the tool mapping await local_mcp_server_manager._initialize_tool_name_to_mcp_server_name_mapping() # Patch the global manager in both modules where it's used with patch('litellm.proxy._experimental.mcp_server.mcp_server_manager.global_mcp_server_manager', local_mcp_server_manager), \ patch('litellm.proxy._experimental.mcp_server.server.global_mcp_server_manager', local_mcp_server_manager): print("tool_name_to_mcp_server_name_mapping", local_mcp_server_manager.tool_name_to_mcp_server_name_mapping) # Manually add the tool mapping to ensure it's available (since mocking might not capture it properly) local_mcp_server_manager.tool_name_to_mcp_server_name_mapping["add_tools"] = "zapier_gmail_server" local_mcp_server_manager.tool_name_to_mcp_server_name_mapping["zapier_gmail_server-add_tools"] = "zapier_gmail_server" # Call mcp tool response = await mcp_server_tool_call( name="zapier_gmail_server-add_tools", # Use correct prefixed name with - separator arguments={ "test": "test" } ) # wait 1-2 seconds for logging to be processed await asyncio.sleep(2) logged_standard_logging_payload = test_logger.standard_logging_payload print("logged_standard_logging_payload", json.dumps(logged_standard_logging_payload, indent=4)) # Add assertions assert response is not None response_list = list(response) # Convert iterable to list assert len(response_list) == 1 assert isinstance(response_list[0], TextContent) assert response_list[0].text == "Test response" # Verify client methods were called mock_client.__aenter__.assert_called() mock_client.call_tool.assert_called_once() ###### # verify response cost is 1.2 as set on default_cost_per_query # Critical - the cost is tracked as $1.2 assert logged_standard_logging_payload is not None, "Standard logging payload should not be None" assert logged_standard_logging_payload["response_cost"] == 1.2 @pytest.mark.asyncio async def test_mcp_cost_tracking_per_tool(): """Test that individual tool costs are tracked correctly when tool_name_to_cost_per_query is configured""" # Create a mock tool call result litellm.logging_callback_manager._reset_all_callbacks() mock_result = CallToolResult( content=[TextContent(type="text", text="Test response")], isError=False ) # Create a mock MCPClient mock_client = AsyncMock() mock_client.call_tool = AsyncMock(return_value=mock_result) mock_client.list_tools = AsyncMock(return_value=[ MCPTool( name="expensive_tool", description="Expensive tool", inputSchema={"type": "object", "properties": {"data": {"type": "string"}}} ), MCPTool( name="cheap_tool", description="Cheap tool", inputSchema={"type": "object", "properties": {"data": {"type": "string"}}} ) ]) mock_client.__aenter__ = AsyncMock(return_value=mock_client) mock_client.__aexit__ = AsyncMock(return_value=None) mock_client.disconnect = AsyncMock(return_value=None) # Mock the MCPClient constructor def mock_client_constructor(*args, **kwargs): return mock_client # Initialize the server manager local_mcp_server_manager = MCPServerManager() with patch('litellm.proxy._experimental.mcp_server.mcp_server_manager.MCPClient', mock_client_constructor): # Load the server config with per-tool costs local_mcp_server_manager.load_servers_from_config( mcp_servers_config={ "test_server": { "url": os.getenv("ZAPIER_MCP_HTTPS_SERVER_URL"), "mcp_info": { "mcp_server_cost_info": { "default_cost_per_query": 0.5, # Default cost "tool_name_to_cost_per_query": { "expensive_tool": 5.0, # High cost tool "cheap_tool": 0.1 # Low cost tool } } } } } ) # Set up the test logger test_logger = TestMCPLogger() litellm.callbacks = [test_logger] # Initialize the tool mapping await local_mcp_server_manager._initialize_tool_name_to_mcp_server_name_mapping() # Manually add the tool mapping to ensure it's available (since mocking might not capture it properly) local_mcp_server_manager.tool_name_to_mcp_server_name_mapping["expensive_tool"] = "test_server" local_mcp_server_manager.tool_name_to_mcp_server_name_mapping["test_server-expensive_tool"] = "test_server" local_mcp_server_manager.tool_name_to_mcp_server_name_mapping["cheap_tool"] = "test_server" local_mcp_server_manager.tool_name_to_mcp_server_name_mapping["test_server-cheap_tool"] = "test_server" # Patch the global manager in both modules where it's used with patch('litellm.proxy._experimental.mcp_server.mcp_server_manager.global_mcp_server_manager', local_mcp_server_manager), \ patch('litellm.proxy._experimental.mcp_server.server.global_mcp_server_manager', local_mcp_server_manager): print("tool_name_to_mcp_server_name_mapping", local_mcp_server_manager.tool_name_to_mcp_server_name_mapping) # Test 1: Call expensive_tool - should cost 5.0 response1 = await mcp_server_tool_call( name="test_server-expensive_tool", # Use correct prefixed name with - separator arguments={ "data": "test_expensive" } ) # wait for logging to be processed await asyncio.sleep(2) logged_standard_logging_payload_1 = test_logger.standard_logging_payload print("logged_standard_logging_payload_1", json.dumps(logged_standard_logging_payload_1, indent=4)) # Verify expensive tool cost assert logged_standard_logging_payload_1 is not None, "Standard logging payload 1 should not be None" assert logged_standard_logging_payload_1["response_cost"] == 5.0 # Reset logger for second test test_logger.standard_logging_payload = None # Test 2: Call cheap_tool - should cost 0.1 response2 = await mcp_server_tool_call( name="test_server-cheap_tool", # Use correct prefixed name with - separator arguments={ "data": "test_cheap" } ) # wait for logging to be processed await asyncio.sleep(2) logged_standard_logging_payload_2 = test_logger.standard_logging_payload print("logged_standard_logging_payload_2", json.dumps(logged_standard_logging_payload_2, indent=4)) # Verify cheap tool cost assert logged_standard_logging_payload_2 is not None, "Standard logging payload 2 should not be None" assert logged_standard_logging_payload_2["response_cost"] == 0.1 # Add basic response assertions assert response1 is not None assert response2 is not None response_list_1 = list(response1) response_list_2 = list(response2) assert len(response_list_1) == 1 assert len(response_list_2) == 1 assert isinstance(response_list_1[0], TextContent) assert isinstance(response_list_2[0], TextContent) assert response_list_1[0].text == "Test response" assert response_list_2[0].text == "Test response" # Verify client methods were called twice assert mock_client.call_tool.call_count == 2 class MCPLoggerHook(CustomLogger): def __init__(self): self.standard_logging_payload = None super().__init__() async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): print("success event") self.standard_logging_payload = kwargs.get("standard_logging_object", None) print(f"Captured standard_logging_payload: {self.standard_logging_payload}") async def async_post_mcp_tool_call_hook(self, kwargs, response_obj: MCPPostCallResponseObject, start_time, end_time) -> Optional[MCPPostCallResponseObject]: print("post mcp tool call response_obj", response_obj) # update the MCPPostCallResponseObject with the response_cost response_obj.hidden_params.response_cost = 1.42 return response_obj @pytest.mark.asyncio async def test_mcp_tool_call_hook(): # Create a mock tool call result litellm.logging_callback_manager._reset_all_callbacks() mock_result = CallToolResult( content=[TextContent(type="text", text="Test response")], isError=False ) # Create a mock MCPClient mock_client = AsyncMock() mock_client.call_tool = AsyncMock(return_value=mock_result) mock_client.list_tools = AsyncMock(return_value=[ MCPTool( name="add_tools", description="Test tool", inputSchema={"type": "object", "properties": {"test": {"type": "string"}}} ) ]) mock_client.__aenter__ = AsyncMock(return_value=mock_client) mock_client.__aexit__ = AsyncMock(return_value=None) mock_client.disconnect = AsyncMock(return_value=None) # Mock the MCPClient constructor def mock_client_constructor(*args, **kwargs): return mock_client # Initialize the server manager local_mcp_server_manager = MCPServerManager() with patch('litellm.proxy._experimental.mcp_server.mcp_server_manager.MCPClient', mock_client_constructor): # Load the server config local_mcp_server_manager.load_servers_from_config( mcp_servers_config={ "zapier_gmail_server": { "url": os.getenv("ZAPIER_MCP_HTTPS_SERVER_URL"), } } ) # Set up the test logger test_logger = MCPLoggerHook() litellm.callbacks = [test_logger] # Initialize the tool mapping await local_mcp_server_manager._initialize_tool_name_to_mcp_server_name_mapping() # Manually add the tool mapping to ensure it's available (since mocking might not capture it properly) local_mcp_server_manager.tool_name_to_mcp_server_name_mapping["add_tools"] = "zapier_gmail_server" local_mcp_server_manager.tool_name_to_mcp_server_name_mapping["zapier_gmail_server-add_tools"] = "zapier_gmail_server" # Patch the global manager in both modules where it's used with patch('litellm.proxy._experimental.mcp_server.mcp_server_manager.global_mcp_server_manager', local_mcp_server_manager), \ patch('litellm.proxy._experimental.mcp_server.server.global_mcp_server_manager', local_mcp_server_manager): print("tool_name_to_mcp_server_name_mapping", local_mcp_server_manager.tool_name_to_mcp_server_name_mapping) # Call mcp tool using the correct separator format (- not /) response = await mcp_server_tool_call( name="zapier_gmail_server-add_tools", # Use correct prefixed name with - separator arguments={ "test": "test" } ) # wait 1-2 seconds for logging to be processed await asyncio.sleep(2) # check logged standard logging payload logged_standard_logging_payload = test_logger.standard_logging_payload print("logged_standard_logging_payload", json.dumps(logged_standard_logging_payload, indent=4)) assert logged_standard_logging_payload is not None, "Standard logging payload should not be None" assert logged_standard_logging_payload["response_cost"] == 1.42