""" Test file for MCP Hook Architecture This file demonstrates the new MCP hook system with comprehensive examples and validation tests. """ import asyncio import pytest from datetime import datetime from typing import Optional from litellm.integrations.custom_logger import CustomLogger from litellm.types.mcp import ( MCPPreCallRequestObject, MCPPreCallResponseObject, MCPDuringCallRequestObject, MCPDuringCallResponseObject, MCPPostCallResponseObject, ) from litellm.types.llms.base import HiddenParams class TestMCPAccessControlHook(CustomLogger): """Test hook for access control functionality""" def __init__(self): self.allowed_tools = {"github/create_issue", "zapier/send_email"} self.blocked_users = {"user123", "user456"} self.call_count = 0 async def async_pre_mcp_tool_call_hook( self, kwargs, request_obj: MCPPreCallRequestObject, start_time, end_time ) -> Optional[MCPPreCallResponseObject]: """Test access control validation""" self.call_count += 1 tool_name = request_obj.tool_name user_id = kwargs.get("user_api_key_auth", {}).get("user_id") # Check if user is blocked if user_id in self.blocked_users: return MCPPreCallResponseObject( should_proceed=False, error_message=f"User {user_id} is not authorized to use MCP tools" ) # Check if tool is allowed if tool_name not in self.allowed_tools: return MCPPreCallResponseObject( should_proceed=False, error_message=f"Tool {tool_name} is not authorized" ) return None # Allow execution to proceed class TestMCPCostTrackingHook(CustomLogger): """Test hook for cost tracking functionality""" def __init__(self): self.cost_map = { "github/create_issue": 0.10, "zapier/send_email": 0.05, "default": 0.01 } self.call_count = 0 async def async_post_mcp_tool_call_hook( self, kwargs, response_obj: MCPPostCallResponseObject, start_time, end_time ) -> Optional[MCPPostCallResponseObject]: """Test cost calculation after tool execution""" self.call_count += 1 tool_name = kwargs.get("name", "") cost = self.cost_map.get(tool_name, self.cost_map["default"]) # Set the response cost response_obj.hidden_params.response_cost = cost return response_obj class TestMCPMonitoringHook(CustomLogger): """Test hook for real-time monitoring functionality""" def __init__(self): self.max_execution_time = 30.0 # seconds self.call_count = 0 async def async_during_mcp_tool_call_hook( self, kwargs, request_obj: MCPDuringCallRequestObject, start_time, end_time ) -> Optional[MCPDuringCallResponseObject]: """Test execution time monitoring""" self.call_count += 1 tool_name = request_obj.tool_name execution_time = (datetime.now() - start_time).total_seconds() # Check if execution is taking too long if execution_time > self.max_execution_time: return MCPDuringCallResponseObject( should_continue=False, error_message=f"Tool {tool_name} execution timeout after {execution_time}s" ) return None # Allow execution to continue class TestMCPArgumentValidationHook(CustomLogger): """Test hook for argument validation functionality""" def __init__(self): self.call_count = 0 async def async_pre_mcp_tool_call_hook( self, kwargs, request_obj: MCPPreCallRequestObject, start_time, end_time ) -> Optional[MCPPreCallResponseObject]: """Test argument validation and sanitization""" self.call_count += 1 tool_name = request_obj.tool_name arguments = request_obj.arguments.copy() # Create a copy to modify # Example: Validate GitHub issue creation if tool_name == "github/create_issue": if not arguments.get("title"): return MCPPreCallResponseObject( should_proceed=False, error_message="GitHub issue title is required" ) # Sanitize the title title = arguments["title"] if len(title) > 100: title = title[:97] + "..." arguments["title"] = title # Example: Validate email sending elif tool_name == "zapier/send_email": if not arguments.get("to"): return MCPPreCallResponseObject( should_proceed=False, error_message="Email recipient is required" ) return MCPPreCallResponseObject( should_proceed=True, modified_arguments=arguments ) # Test fixtures @pytest.fixture def access_control_hook(): return TestMCPAccessControlHook() @pytest.fixture def cost_tracking_hook(): return TestMCPCostTrackingHook() @pytest.fixture def monitoring_hook(): return TestMCPMonitoringHook() @pytest.fixture def argument_validation_hook(): return TestMCPArgumentValidationHook() # Test cases class TestMCPHooks: """Test cases for MCP hook functionality""" @pytest.mark.asyncio async def test_access_control_hook_allowed_tool(self, access_control_hook): """Test that allowed tools pass validation""" kwargs = { "user_api_key_auth": {"user_id": "user789"}, "name": "github/create_issue" } request_obj = MCPPreCallRequestObject( tool_name="github/create_issue", arguments={"title": "Test issue"}, user_api_key_auth={"user_id": "user789"} ) result = await access_control_hook.async_pre_mcp_tool_call_hook( kwargs=kwargs, request_obj=request_obj, start_time=datetime.now(), end_time=datetime.now() ) assert result is None # Should allow execution assert access_control_hook.call_count == 1 @pytest.mark.asyncio async def test_access_control_hook_blocked_user(self, access_control_hook): """Test that blocked users are rejected""" kwargs = { "user_api_key_auth": {"user_id": "user123"}, "name": "github/create_issue" } request_obj = MCPPreCallRequestObject( tool_name="github/create_issue", arguments={"title": "Test issue"}, user_api_key_auth={"user_id": "user123"} ) result = await access_control_hook.async_pre_mcp_tool_call_hook( kwargs=kwargs, request_obj=request_obj, start_time=datetime.now(), end_time=datetime.now() ) assert result is not None assert result.should_proceed is False assert "not authorized" in result.error_message @pytest.mark.asyncio async def test_access_control_hook_unauthorized_tool(self, access_control_hook): """Test that unauthorized tools are rejected""" kwargs = { "user_api_key_auth": {"user_id": "user789"}, "name": "unauthorized_tool" } request_obj = MCPPreCallRequestObject( tool_name="unauthorized_tool", arguments={"param": "value"}, user_api_key_auth={"user_id": "user789"} ) result = await access_control_hook.async_pre_mcp_tool_call_hook( kwargs=kwargs, request_obj=request_obj, start_time=datetime.now(), end_time=datetime.now() ) assert result is not None assert result.should_proceed is False assert "not authorized" in result.error_message @pytest.mark.asyncio async def test_cost_tracking_hook(self, cost_tracking_hook): """Test cost tracking functionality""" kwargs = {"name": "github/create_issue"} response_obj = MCPPostCallResponseObject( mcp_tool_call_response=[], hidden_params=HiddenParams() ) result = await cost_tracking_hook.async_post_mcp_tool_call_hook( kwargs=kwargs, response_obj=response_obj, start_time=datetime.now(), end_time=datetime.now() ) assert result is not None assert result.hidden_params.response_cost == 0.10 assert cost_tracking_hook.call_count == 1 @pytest.mark.asyncio async def test_cost_tracking_hook_default_cost(self, cost_tracking_hook): """Test default cost assignment""" kwargs = {"name": "unknown_tool"} response_obj = MCPPostCallResponseObject( mcp_tool_call_response=[], hidden_params=HiddenParams() ) result = await cost_tracking_hook.async_post_mcp_tool_call_hook( kwargs=kwargs, response_obj=response_obj, start_time=datetime.now(), end_time=datetime.now() ) assert result is not None assert result.hidden_params.response_cost == 0.01 # Default cost @pytest.mark.asyncio async def test_monitoring_hook_normal_execution(self, monitoring_hook): """Test monitoring hook with normal execution time""" kwargs = {"name": "test_tool"} request_obj = MCPDuringCallRequestObject( tool_name="test_tool", arguments={}, start_time=datetime.now().timestamp() ) result = await monitoring_hook.async_during_mcp_tool_call_hook( kwargs=kwargs, request_obj=request_obj, start_time=datetime.now(), end_time=datetime.now() ) assert result is None # Should allow execution to continue assert monitoring_hook.call_count == 1 @pytest.mark.asyncio async def test_argument_validation_hook_valid_github_issue(self, argument_validation_hook): """Test argument validation for valid GitHub issue""" kwargs = {"name": "github/create_issue"} request_obj = MCPPreCallRequestObject( tool_name="github/create_issue", arguments={"title": "Valid issue title"} ) result = await argument_validation_hook.async_pre_mcp_tool_call_hook( kwargs=kwargs, request_obj=request_obj, start_time=datetime.now(), end_time=datetime.now() ) assert result is not None assert result.should_proceed is True assert result.modified_arguments == {"title": "Valid issue title"} assert argument_validation_hook.call_count == 1 @pytest.mark.asyncio async def test_argument_validation_hook_missing_title(self, argument_validation_hook): """Test argument validation for missing GitHub issue title""" kwargs = {"name": "github/create_issue"} request_obj = MCPPreCallRequestObject( tool_name="github/create_issue", arguments={} # Missing title ) result = await argument_validation_hook.async_pre_mcp_tool_call_hook( kwargs=kwargs, request_obj=request_obj, start_time=datetime.now(), end_time=datetime.now() ) assert result is not None assert result.should_proceed is False assert "title is required" in result.error_message @pytest.mark.asyncio async def test_argument_validation_hook_long_title_sanitization(self, argument_validation_hook): """Test argument validation with title sanitization""" kwargs = {"name": "github/create_issue"} long_title = "A" * 150 # Very long title request_obj = MCPPreCallRequestObject( tool_name="github/create_issue", arguments={"title": long_title} ) result = await argument_validation_hook.async_pre_mcp_tool_call_hook( kwargs=kwargs, request_obj=request_obj, start_time=datetime.now(), end_time=datetime.now() ) assert result is not None assert result.should_proceed is True assert len(result.modified_arguments["title"]) == 100 # Truncated assert result.modified_arguments["title"].endswith("...") @pytest.mark.asyncio async def test_argument_validation_hook_email_validation(self, argument_validation_hook): """Test argument validation for email sending""" kwargs = {"name": "zapier/send_email"} request_obj = MCPPreCallRequestObject( tool_name="zapier/send_email", arguments={"to": "test@example.com", "subject": "Test"} ) result = await argument_validation_hook.async_pre_mcp_tool_call_hook( kwargs=kwargs, request_obj=request_obj, start_time=datetime.now(), end_time=datetime.now() ) assert result is not None assert result.should_proceed is True assert result.modified_arguments == {"to": "test@example.com", "subject": "Test"} @pytest.mark.asyncio async def test_argument_validation_hook_missing_email_recipient(self, argument_validation_hook): """Test argument validation for missing email recipient""" kwargs = {"name": "zapier/send_email"} request_obj = MCPPreCallRequestObject( tool_name="zapier/send_email", arguments={"subject": "Test"} # Missing 'to' field ) result = await argument_validation_hook.async_pre_mcp_tool_call_hook( kwargs=kwargs, request_obj=request_obj, start_time=datetime.now(), end_time=datetime.now() ) assert result is not None assert result.should_proceed is False assert "recipient is required" in result.error_message # Integration test class TestMCPHookIntegration: """Integration tests for MCP hook system""" @pytest.mark.asyncio async def test_hook_chain_execution(self): """Test that multiple hooks can work together""" access_hook = TestMCPAccessControlHook() cost_hook = TestMCPCostTrackingHook() validation_hook = TestMCPArgumentValidationHook() # Test data kwargs = { "user_api_key_auth": {"user_id": "user789"}, "name": "github/create_issue" } request_obj = MCPPreCallRequestObject( tool_name="github/create_issue", arguments={"title": "Integration test issue"}, user_api_key_auth={"user_id": "user789"} ) # Execute pre-hooks access_result = await access_hook.async_pre_mcp_tool_call_hook( kwargs=kwargs, request_obj=request_obj, start_time=datetime.now(), end_time=datetime.now() ) validation_result = await validation_hook.async_pre_mcp_tool_call_hook( kwargs=kwargs, request_obj=request_obj, start_time=datetime.now(), end_time=datetime.now() ) # Both hooks should allow execution assert access_result is None assert validation_result is not None assert validation_result.should_proceed is True # Simulate post-hook execution response_obj = MCPPostCallResponseObject( mcp_tool_call_response=[], hidden_params=HiddenParams() ) cost_result = await cost_hook.async_post_mcp_tool_call_hook( kwargs=kwargs, response_obj=response_obj, start_time=datetime.now(), end_time=datetime.now() ) assert cost_result is not None assert cost_result.hidden_params.response_cost == 0.10 if __name__ == "__main__": # Run the tests pytest.main([__file__, "-v"])