# Source listing: Homelab/Development/litellm/tests/mcp_tests/test_mcp_hooks.py
# (file-browser metadata: 489 lines, 16 KiB, Python)

"""
Test file for MCP Hook Architecture
This file demonstrates the new MCP hook system with comprehensive examples
and validation tests.
"""
import asyncio
import pytest
from datetime import datetime
from typing import Optional
from litellm.integrations.custom_logger import CustomLogger
from litellm.types.mcp import (
MCPPreCallRequestObject,
MCPPreCallResponseObject,
MCPDuringCallRequestObject,
MCPDuringCallResponseObject,
MCPPostCallResponseObject,
)
from litellm.types.llms.base import HiddenParams
class TestMCPAccessControlHook(CustomLogger):
    """Pre-call hook that enforces a tool allow-list and a user block-list.

    Helper used by the tests in this module; it is not itself a test suite.
    """

    # The ``Test`` prefix would otherwise make pytest try to collect this
    # helper class and emit PytestCollectionWarning (it defines __init__).
    __test__ = False

    def __init__(self, allowed_tools=None, blocked_users=None):
        """Create the hook.

        Args:
            allowed_tools: Tool names that may be called. Defaults to
                ``{"github/create_issue", "zapier/send_email"}``.
            blocked_users: User ids that may not call any tool. Defaults to
                ``{"user123", "user456"}``.
        """
        super().__init__()
        self.allowed_tools = (
            set(allowed_tools)
            if allowed_tools is not None
            else {"github/create_issue", "zapier/send_email"}
        )
        self.blocked_users = (
            set(blocked_users)
            if blocked_users is not None
            else {"user123", "user456"}
        )
        # Number of times the pre-call hook has run; asserted by the tests.
        self.call_count = 0

    @staticmethod
    def _extract_user_id(auth) -> Optional[str]:
        """Return ``user_id`` from a dict- or attribute-style auth value, else None."""
        if auth is None:
            return None
        if isinstance(auth, dict):
            return auth.get("user_id")
        return getattr(auth, "user_id", None)

    async def async_pre_mcp_tool_call_hook(
        self,
        kwargs,
        request_obj: MCPPreCallRequestObject,
        start_time,
        end_time,
    ) -> Optional[MCPPreCallResponseObject]:
        """Reject calls from blocked users or to tools outside the allow-list.

        Returns:
            ``None`` to let the call proceed, or a response with
            ``should_proceed=False`` and an explanatory ``error_message``.
        """
        self.call_count += 1
        tool_name = request_obj.tool_name

        # Prefer the auth passed in kwargs; fall back to the request object,
        # which also carries ``user_api_key_auth``. Either may be a plain dict
        # or an object exposing ``user_id``.
        auth = kwargs.get("user_api_key_auth")
        if auth is None:
            auth = getattr(request_obj, "user_api_key_auth", None)
        user_id = self._extract_user_id(auth)

        # Check if user is blocked
        if user_id in self.blocked_users:
            return MCPPreCallResponseObject(
                should_proceed=False,
                error_message=f"User {user_id} is not authorized to use MCP tools",
            )

        # Check if tool is allowed
        if tool_name not in self.allowed_tools:
            return MCPPreCallResponseObject(
                should_proceed=False,
                error_message=f"Tool {tool_name} is not authorized",
            )

        return None  # Allow execution to proceed
class TestMCPCostTrackingHook(CustomLogger):
    """Post-call hook that stamps a per-tool cost onto the response.

    Helper used by the tests in this module; it is not itself a test suite.
    """

    # The ``Test`` prefix would otherwise make pytest try to collect this
    # helper class and emit PytestCollectionWarning (it defines __init__).
    __test__ = False

    def __init__(self, cost_map=None):
        """Create the hook.

        Args:
            cost_map: Mapping of tool name to cost. The optional ``"default"``
                entry prices tools not listed explicitly. Defaults to
                ``{"github/create_issue": 0.10, "zapier/send_email": 0.05,
                "default": 0.01}``.
        """
        super().__init__()
        self.cost_map = (
            dict(cost_map)
            if cost_map is not None
            else {
                "github/create_issue": 0.10,
                "zapier/send_email": 0.05,
                "default": 0.01,
            }
        )
        # Number of times the post-call hook has run; asserted by the tests.
        self.call_count = 0

    async def async_post_mcp_tool_call_hook(
        self,
        kwargs,
        response_obj: MCPPostCallResponseObject,
        start_time,
        end_time,
    ) -> Optional[MCPPostCallResponseObject]:
        """Set ``hidden_params.response_cost`` for the tool named in ``kwargs``.

        Returns:
            The same ``response_obj``, mutated in place with the cost.
        """
        self.call_count += 1
        tool_name = kwargs.get("name", "")
        # A custom cost_map may omit "default"; price unknown tools at 0 then
        # instead of raising KeyError.
        default_cost = self.cost_map.get("default", 0.0)
        cost = self.cost_map.get(tool_name, default_cost)

        # Set the response cost
        response_obj.hidden_params.response_cost = cost
        return response_obj
class TestMCPMonitoringHook(CustomLogger):
    """During-call hook that stops MCP tool calls exceeding a time budget.

    Helper used by the tests in this module; it is not itself a test suite.
    """

    # The ``Test`` prefix would otherwise make pytest try to collect this
    # helper class and emit PytestCollectionWarning (it defines __init__).
    __test__ = False

    def __init__(self, max_execution_time: float = 30.0):
        """Create the hook.

        Args:
            max_execution_time: Budget in seconds; calls running longer are
                asked to stop. Defaults to 30 seconds.
        """
        super().__init__()
        self.max_execution_time = max_execution_time  # seconds
        # Number of times the during-call hook has run; asserted by the tests.
        self.call_count = 0

    @staticmethod
    def _elapsed_seconds(start_time) -> Optional[float]:
        """Return seconds elapsed since ``start_time``, or None if it is unknown.

        ``start_time`` may be a naive or timezone-aware ``datetime`` or a
        numeric POSIX timestamp (e.g. a ``datetime.timestamp()`` value).
        """
        if start_time is None:
            return None
        if isinstance(start_time, datetime):
            # Take "now" in start_time's own timezone (None -> naive) so naive
            # and aware datetimes are never subtracted from each other.
            now = datetime.now(tz=start_time.tzinfo)
            return (now - start_time).total_seconds()
        return datetime.now().timestamp() - float(start_time)

    async def async_during_mcp_tool_call_hook(
        self,
        kwargs,
        request_obj: MCPDuringCallRequestObject,
        start_time,
        end_time,
    ) -> Optional[MCPDuringCallResponseObject]:
        """Ask the call to stop once it has run longer than ``max_execution_time``.

        Returns:
            ``None`` to let execution continue, or a response with
            ``should_continue=False`` and a timeout ``error_message``.
        """
        self.call_count += 1
        tool_name = request_obj.tool_name

        # Fall back to the request's own start timestamp if none was passed.
        if start_time is None:
            start_time = getattr(request_obj, "start_time", None)
        execution_time = self._elapsed_seconds(start_time)
        if execution_time is None:
            return None  # Nothing to measure against; let execution continue.

        # Check if execution is taking too long
        if execution_time > self.max_execution_time:
            return MCPDuringCallResponseObject(
                should_continue=False,
                error_message=f"Tool {tool_name} execution timeout after {execution_time}s",
            )

        return None  # Allow execution to continue
class TestMCPArgumentValidationHook(CustomLogger):
    """Pre-call hook that validates and sanitizes tool arguments.

    Helper used by the tests in this module; it is not itself a test suite.
    """

    # The ``Test`` prefix would otherwise make pytest try to collect this
    # helper class and emit PytestCollectionWarning (it defines __init__).
    __test__ = False

    def __init__(self, max_title_length: int = 100):
        """Create the hook.

        Args:
            max_title_length: Longest GitHub issue title allowed through
                unchanged. Longer titles are truncated to exactly this length,
                including the trailing ``"..."``.
        """
        super().__init__()
        self.max_title_length = max_title_length
        # Number of times the pre-call hook has run; asserted by the tests.
        self.call_count = 0

    async def async_pre_mcp_tool_call_hook(
        self,
        kwargs,
        request_obj: MCPPreCallRequestObject,
        start_time,
        end_time,
    ) -> Optional[MCPPreCallResponseObject]:
        """Validate required arguments per tool and truncate long issue titles.

        Returns:
            A response with ``should_proceed=False`` and an ``error_message``
            when a required argument is missing. Otherwise a response with
            ``should_proceed=True`` and the (possibly sanitized)
            ``modified_arguments``.
        """
        self.call_count += 1
        tool_name = request_obj.tool_name
        # Work on a shallow copy so the caller's request is never mutated, and
        # tolerate a request that carries no arguments at all.
        arguments = dict(request_obj.arguments or {})

        # Example: Validate GitHub issue creation
        if tool_name == "github/create_issue":
            title = arguments.get("title")
            if not title:
                return MCPPreCallResponseObject(
                    should_proceed=False,
                    error_message="GitHub issue title is required",
                )
            # Sanitize the title (only strings have a meaningful length here).
            if isinstance(title, str) and len(title) > self.max_title_length:
                ellipsis = "..."
                keep = max(self.max_title_length - len(ellipsis), 0)
                arguments["title"] = title[:keep] + ellipsis

        # Example: Validate email sending
        elif tool_name == "zapier/send_email":
            if not arguments.get("to"):
                return MCPPreCallResponseObject(
                    should_proceed=False,
                    error_message="Email recipient is required",
                )

        return MCPPreCallResponseObject(
            should_proceed=True,
            modified_arguments=arguments,
        )
# Test fixtures
@pytest.fixture
def access_control_hook():
    """Yield a brand-new access-control hook so call counts start at zero."""
    hook = TestMCPAccessControlHook()
    return hook
@pytest.fixture
def cost_tracking_hook():
    """Yield a brand-new cost-tracking hook so call counts start at zero."""
    hook = TestMCPCostTrackingHook()
    return hook
@pytest.fixture
def monitoring_hook():
    """Yield a brand-new monitoring hook so call counts start at zero."""
    hook = TestMCPMonitoringHook()
    return hook
@pytest.fixture
def argument_validation_hook():
    """Yield a brand-new argument-validation hook so call counts start at zero."""
    hook = TestMCPArgumentValidationHook()
    return hook
# Test cases
class TestMCPHooks:
    """Unit tests that exercise each MCP hook in isolation."""

    @pytest.mark.asyncio
    async def test_access_control_hook_allowed_tool(self, access_control_hook):
        """An allowed tool called by a non-blocked user is not intercepted."""
        kwargs = {
            "user_api_key_auth": {"user_id": "user789"},
            "name": "github/create_issue",
        }
        request_obj = MCPPreCallRequestObject(
            tool_name="github/create_issue",
            arguments={"title": "Test issue"},
            user_api_key_auth={"user_id": "user789"},
        )
        result = await access_control_hook.async_pre_mcp_tool_call_hook(
            kwargs=kwargs,
            request_obj=request_obj,
            start_time=datetime.now(),
            end_time=datetime.now(),
        )
        assert result is None  # Should allow execution
        assert access_control_hook.call_count == 1

    @pytest.mark.asyncio
    async def test_access_control_hook_blocked_user(self, access_control_hook):
        """A blocked user is rejected even for an allowed tool."""
        kwargs = {
            "user_api_key_auth": {"user_id": "user123"},
            "name": "github/create_issue",
        }
        request_obj = MCPPreCallRequestObject(
            tool_name="github/create_issue",
            arguments={"title": "Test issue"},
            user_api_key_auth={"user_id": "user123"},
        )
        result = await access_control_hook.async_pre_mcp_tool_call_hook(
            kwargs=kwargs,
            request_obj=request_obj,
            start_time=datetime.now(),
            end_time=datetime.now(),
        )
        assert result is not None
        assert result.should_proceed is False
        assert "not authorized" in result.error_message
        # Both rejection messages say "not authorized"; pin the user branch.
        assert "user123" in result.error_message
        assert access_control_hook.call_count == 1

    @pytest.mark.asyncio
    async def test_access_control_hook_unauthorized_tool(self, access_control_hook):
        """A tool outside the allow-list is rejected for a non-blocked user."""
        kwargs = {
            "user_api_key_auth": {"user_id": "user789"},
            "name": "unauthorized_tool",
        }
        request_obj = MCPPreCallRequestObject(
            tool_name="unauthorized_tool",
            arguments={"param": "value"},
            user_api_key_auth={"user_id": "user789"},
        )
        result = await access_control_hook.async_pre_mcp_tool_call_hook(
            kwargs=kwargs,
            request_obj=request_obj,
            start_time=datetime.now(),
            end_time=datetime.now(),
        )
        assert result is not None
        assert result.should_proceed is False
        assert "not authorized" in result.error_message
        # Both rejection messages say "not authorized"; pin the tool branch.
        assert "unauthorized_tool" in result.error_message
        assert access_control_hook.call_count == 1

    @pytest.mark.asyncio
    async def test_cost_tracking_hook(self, cost_tracking_hook):
        """A known tool is priced from the cost map."""
        kwargs = {"name": "github/create_issue"}
        response_obj = MCPPostCallResponseObject(
            mcp_tool_call_response=[],
            hidden_params=HiddenParams(),
        )
        result = await cost_tracking_hook.async_post_mcp_tool_call_hook(
            kwargs=kwargs,
            response_obj=response_obj,
            start_time=datetime.now(),
            end_time=datetime.now(),
        )
        assert result is not None
        assert result.hidden_params.response_cost == 0.10
        assert cost_tracking_hook.call_count == 1

    @pytest.mark.asyncio
    async def test_cost_tracking_hook_default_cost(self, cost_tracking_hook):
        """An unknown tool falls back to the default cost."""
        kwargs = {"name": "unknown_tool"}
        response_obj = MCPPostCallResponseObject(
            mcp_tool_call_response=[],
            hidden_params=HiddenParams(),
        )
        result = await cost_tracking_hook.async_post_mcp_tool_call_hook(
            kwargs=kwargs,
            response_obj=response_obj,
            start_time=datetime.now(),
            end_time=datetime.now(),
        )
        assert result is not None
        assert result.hidden_params.response_cost == 0.01  # Default cost
        assert cost_tracking_hook.call_count == 1

    @pytest.mark.asyncio
    async def test_monitoring_hook_normal_execution(self, monitoring_hook):
        """A call that has just started is allowed to continue."""
        kwargs = {"name": "test_tool"}
        request_obj = MCPDuringCallRequestObject(
            tool_name="test_tool",
            arguments={},
            start_time=datetime.now().timestamp(),
        )
        result = await monitoring_hook.async_during_mcp_tool_call_hook(
            kwargs=kwargs,
            request_obj=request_obj,
            start_time=datetime.now(),
            end_time=datetime.now(),
        )
        assert result is None  # Should allow execution to continue
        assert monitoring_hook.call_count == 1

    @pytest.mark.asyncio
    async def test_monitoring_hook_timeout(self, monitoring_hook):
        """A call running past the budget is asked to stop."""
        from datetime import timedelta

        # Start comfortably before the budget so the check is not timing-sensitive.
        started = datetime.now() - timedelta(
            seconds=monitoring_hook.max_execution_time + 5
        )
        request_obj = MCPDuringCallRequestObject(
            tool_name="slow_tool",
            arguments={},
            start_time=started.timestamp(),
        )
        result = await monitoring_hook.async_during_mcp_tool_call_hook(
            kwargs={"name": "slow_tool"},
            request_obj=request_obj,
            start_time=started,
            end_time=datetime.now(),
        )
        assert result is not None
        assert result.should_continue is False
        assert "slow_tool" in result.error_message
        assert "timeout" in result.error_message
        assert monitoring_hook.call_count == 1

    @pytest.mark.asyncio
    async def test_argument_validation_hook_valid_github_issue(self, argument_validation_hook):
        """A valid GitHub issue passes through with its arguments unchanged."""
        kwargs = {"name": "github/create_issue"}
        request_obj = MCPPreCallRequestObject(
            tool_name="github/create_issue",
            arguments={"title": "Valid issue title"},
        )
        result = await argument_validation_hook.async_pre_mcp_tool_call_hook(
            kwargs=kwargs,
            request_obj=request_obj,
            start_time=datetime.now(),
            end_time=datetime.now(),
        )
        assert result is not None
        assert result.should_proceed is True
        assert result.modified_arguments == {"title": "Valid issue title"}
        assert argument_validation_hook.call_count == 1

    @pytest.mark.asyncio
    async def test_argument_validation_hook_missing_title(self, argument_validation_hook):
        """A GitHub issue without a title is rejected."""
        kwargs = {"name": "github/create_issue"}
        request_obj = MCPPreCallRequestObject(
            tool_name="github/create_issue",
            arguments={},  # Missing title
        )
        result = await argument_validation_hook.async_pre_mcp_tool_call_hook(
            kwargs=kwargs,
            request_obj=request_obj,
            start_time=datetime.now(),
            end_time=datetime.now(),
        )
        assert result is not None
        assert result.should_proceed is False
        assert "title is required" in result.error_message
        assert argument_validation_hook.call_count == 1

    @pytest.mark.asyncio
    async def test_argument_validation_hook_long_title_sanitization(self, argument_validation_hook):
        """An over-long title is truncated without mutating the caller's request."""
        kwargs = {"name": "github/create_issue"}
        long_title = "A" * 150  # Very long title
        request_obj = MCPPreCallRequestObject(
            tool_name="github/create_issue",
            arguments={"title": long_title},
        )
        result = await argument_validation_hook.async_pre_mcp_tool_call_hook(
            kwargs=kwargs,
            request_obj=request_obj,
            start_time=datetime.now(),
            end_time=datetime.now(),
        )
        assert result is not None
        assert result.should_proceed is True
        assert len(result.modified_arguments["title"]) == 100  # Truncated
        assert result.modified_arguments["title"].endswith("...")
        # The hook must sanitize a copy, leaving the original request intact.
        assert request_obj.arguments["title"] == long_title
        assert argument_validation_hook.call_count == 1

    @pytest.mark.asyncio
    async def test_argument_validation_hook_email_validation(self, argument_validation_hook):
        """A well-formed email request passes through unchanged."""
        kwargs = {"name": "zapier/send_email"}
        request_obj = MCPPreCallRequestObject(
            tool_name="zapier/send_email",
            arguments={"to": "test@example.com", "subject": "Test"},
        )
        result = await argument_validation_hook.async_pre_mcp_tool_call_hook(
            kwargs=kwargs,
            request_obj=request_obj,
            start_time=datetime.now(),
            end_time=datetime.now(),
        )
        assert result is not None
        assert result.should_proceed is True
        assert result.modified_arguments == {"to": "test@example.com", "subject": "Test"}
        assert argument_validation_hook.call_count == 1

    @pytest.mark.asyncio
    async def test_argument_validation_hook_missing_email_recipient(self, argument_validation_hook):
        """An email request without a recipient is rejected."""
        kwargs = {"name": "zapier/send_email"}
        request_obj = MCPPreCallRequestObject(
            tool_name="zapier/send_email",
            arguments={"subject": "Test"},  # Missing 'to' field
        )
        result = await argument_validation_hook.async_pre_mcp_tool_call_hook(
            kwargs=kwargs,
            request_obj=request_obj,
            start_time=datetime.now(),
            end_time=datetime.now(),
        )
        assert result is not None
        assert result.should_proceed is False
        assert "recipient is required" in result.error_message
        assert argument_validation_hook.call_count == 1
# Integration test
class TestMCPHookIntegration:
    """Integration checks that chain several MCP hooks the way a call would."""

    @pytest.mark.asyncio
    async def test_hook_chain_execution(self):
        """Pre-call hooks approve the request, then the post-call hook prices it."""
        tool = "github/create_issue"
        access_hook = TestMCPAccessControlHook()
        cost_hook = TestMCPCostTrackingHook()
        validation_hook = TestMCPArgumentValidationHook()

        call_kwargs = {
            "user_api_key_auth": {"user_id": "user789"},
            "name": tool,
        }
        pre_request = MCPPreCallRequestObject(
            tool_name=tool,
            arguments={"title": "Integration test issue"},
            user_api_key_auth={"user_id": "user789"},
        )

        # Pre-call phase: access control first, then argument validation.
        pre_results = []
        for hook in (access_hook, validation_hook):
            outcome = await hook.async_pre_mcp_tool_call_hook(
                kwargs=call_kwargs,
                request_obj=pre_request,
                start_time=datetime.now(),
                end_time=datetime.now(),
            )
            pre_results.append(outcome)
        access_result, validation_result = pre_results

        # Access control returns nothing (no objection); validation approves.
        assert access_result is None
        assert validation_result is not None
        assert validation_result.should_proceed is True

        # Post-call phase: attach the tool's cost to a simulated response.
        simulated_response = MCPPostCallResponseObject(
            mcp_tool_call_response=[],
            hidden_params=HiddenParams(),
        )
        priced = await cost_hook.async_post_mcp_tool_call_hook(
            kwargs=call_kwargs,
            response_obj=simulated_response,
            start_time=datetime.now(),
            end_time=datetime.now(),
        )
        assert priced is not None
        assert priced.hidden_params.response_cost == 0.10
if __name__ == "__main__":
    # Run this file's tests directly. pytest.main returns the exit code
    # instead of exiting, so propagate it; otherwise failures exit with 0.
    raise SystemExit(pytest.main([__file__, "-v"]))