""" Test file for MCP Guardrails Feature This file tests the MCP guardrails functionality for both pre and during MCP call hooks, including various guardrail types and proper exception handling. """ import asyncio import pytest import sys import os from datetime import datetime from typing import Optional, Dict, Any from unittest.mock import MagicMock, AsyncMock, patch # Add the project root to the path sys.path.insert(0, os.path.abspath("../..")) import litellm from litellm.exceptions import BlockedPiiEntityError, GuardrailRaisedException from litellm.integrations.custom_guardrail import CustomGuardrail from litellm.integrations.custom_logger import CustomLogger from litellm.proxy._types import UserAPIKeyAuth from litellm.caching.caching import DualCache from litellm.types.mcp import ( MCPPreCallRequestObject, MCPPreCallResponseObject, MCPDuringCallRequestObject, MCPDuringCallResponseObject, ) from litellm.types.llms.base import HiddenParams from litellm.types.guardrails import GuardrailEventHooks from fastapi import HTTPException class MockPiiGuardrail(CustomGuardrail): """Mock PII guardrail that raises BlockedPiiEntityError""" def __init__(self, should_block: bool = True, entity_type: str = "EMAIL_ADDRESS"): super().__init__() self.should_block = should_block self.entity_type = entity_type self.guardrail_name = "mock-pii-guardrail" self.call_count = 0 def should_run_guardrail(self, data: dict, event_type: GuardrailEventHooks) -> bool: """Always run for testing""" return True async def async_pre_call_hook( self, user_api_key_dict: UserAPIKeyAuth, cache: DualCache, data: dict, call_type: str, ): """Mock pre-call hook that raises BlockedPiiEntityError""" self.call_count += 1 if self.should_block: raise BlockedPiiEntityError( entity_type=self.entity_type, guardrail_name=self.guardrail_name, ) return None class MockContentGuardrail(CustomGuardrail): """Mock content guardrail that raises GuardrailRaisedException""" def __init__(self, should_block: bool = True): super().__init__() self.should_block = should_block self.guardrail_name = "mock-content-guardrail" self.call_count = 0 def should_run_guardrail(self, data: dict, event_type: GuardrailEventHooks) -> bool: """Always run for testing""" return True async def async_pre_call_hook( self, user_api_key_dict: UserAPIKeyAuth, cache: DualCache, data: dict, call_type: str, ): """Mock pre-call hook that raises GuardrailRaisedException""" self.call_count += 1 if self.should_block: raise GuardrailRaisedException( guardrail_name=self.guardrail_name, message="Content violates policy" ) return None class MockHttpGuardrail(CustomGuardrail): """Mock HTTP guardrail that raises HTTPException""" def __init__(self, should_block: bool = True): super().__init__() self.should_block = should_block self.guardrail_name = "mock-http-guardrail" self.call_count = 0 def should_run_guardrail(self, data: dict, event_type: GuardrailEventHooks) -> bool: """Always run for testing""" return True async def async_pre_call_hook( self, user_api_key_dict: UserAPIKeyAuth, cache: DualCache, data: dict, call_type: str, ): """Mock pre-call hook that raises HTTPException""" self.call_count += 1 if self.should_block: raise HTTPException( status_code=400, detail={"error": "Violated guardrail policy"} ) return None class MockDuringCallGuardrail(CustomGuardrail): """Mock guardrail for during-call testing""" def __init__(self, should_block: bool = True): super().__init__() self.should_block = should_block self.guardrail_name = "mock-during-guardrail" self.call_count = 0 def should_run_guardrail(self, data: dict, event_type: GuardrailEventHooks) -> bool: """Always run for testing""" return True async def async_moderation_hook( self, data: dict, user_api_key_dict: UserAPIKeyAuth, call_type: str, ): """Mock during-call hook that raises exceptions""" self.call_count += 1 if self.should_block: raise BlockedPiiEntityError( entity_type="PHONE_NUMBER", guardrail_name=self.guardrail_name, ) return None class MockProxyLogging: """Mock proxy logging object for testing MCP guardrails""" def __init__(self, guardrails: Optional[list] = None): self.guardrails = guardrails if guardrails is not None else [] self.call_details = {"user_api_key_cache": DualCache()} self.dynamic_success_callbacks = [] self.call_count = 0 def get_combined_callback_list(self, dynamic_success_callbacks, global_callbacks): """Return the guardrails for testing""" return self.guardrails def _convert_mcp_to_llm_format(self, request_obj, kwargs: dict) -> dict: """Convert MCP tool call to LLM message format""" tool_call_content = f"Tool: {request_obj.tool_name}\nArguments: {request_obj.arguments}" return { "messages": [{"role": "user", "content": tool_call_content}], "model": kwargs.get("model", "mcp-tool-call"), "user_api_key_user_id": kwargs.get("user_api_key_user_id"), "user_api_key_team_id": kwargs.get("user_api_key_team_id"), } def _convert_llm_result_to_mcp_response(self, llm_result, request_obj): """Convert LLM result back to MCP response format""" return None # For testing, we don't need to convert back def _parse_pre_mcp_call_hook_response(self, response, original_request): """Parse pre MCP call hook response""" return response async def async_pre_mcp_tool_call_hook( self, kwargs: dict, request_obj: Any, start_time: datetime, end_time: datetime, ) -> Optional[Any]: """Mock pre MCP tool call hook""" self.call_count += 1 # Simulate the actual hook logic for guardrail in self.guardrails: if isinstance(guardrail, CustomGuardrail): try: synthetic_data = self._convert_mcp_to_llm_format(request_obj, kwargs) # Check if guardrail should run if not guardrail.should_run_guardrail(synthetic_data, GuardrailEventHooks.pre_mcp_call): continue result = await guardrail.async_pre_call_hook( user_api_key_dict=kwargs.get("user_api_key_auth"), cache=self.call_details["user_api_key_cache"], data=synthetic_data, call_type="mcp_call" ) if result is not None: return self._parse_pre_mcp_call_hook_response(result, request_obj) except (BlockedPiiEntityError, GuardrailRaisedException, HTTPException) as e: # Re-raise guardrail exceptions raise e except Exception as e: # Log non-guardrail exceptions as non-blocking print(f"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while logging {str(e)}") return None async def async_during_mcp_tool_call_hook( self, kwargs: dict, request_obj: Any, start_time: datetime, end_time: datetime, ) -> Optional[Any]: """Mock during MCP tool call hook""" self.call_count += 1 # Simulate the actual hook logic for guardrail in self.guardrails: if isinstance(guardrail, CustomGuardrail): try: synthetic_data = self._convert_mcp_to_llm_format(request_obj, kwargs) result = await guardrail.async_moderation_hook( data=synthetic_data, user_api_key_dict=kwargs.get("user_api_key_auth"), call_type="mcp_call" ) if result is not None: return result except (BlockedPiiEntityError, GuardrailRaisedException, HTTPException) as e: # Re-raise guardrail exceptions raise e except Exception as e: # Log non-guardrail exceptions as non-blocking print(f"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while logging {str(e)}") return None @pytest.fixture def mock_user_api_key(): """Mock user API key for testing""" return UserAPIKeyAuth(api_key="test_key", user_id="test_user") @pytest.fixture def mock_cache(): """Mock cache for testing""" return DualCache() @pytest.fixture def mock_pii_guardrail(): """Mock PII guardrail that blocks""" return MockPiiGuardrail(should_block=True) @pytest.fixture def mock_pii_guardrail_allow(): """Mock PII guardrail that allows""" return MockPiiGuardrail(should_block=False) @pytest.fixture def mock_content_guardrail(): """Mock content guardrail that blocks""" return MockContentGuardrail(should_block=True) @pytest.fixture def mock_http_guardrail(): """Mock HTTP guardrail that blocks""" return MockHttpGuardrail(should_block=True) @pytest.fixture def mock_during_guardrail(): """Mock during-call guardrail that blocks""" return MockDuringCallGuardrail(should_block=True) @pytest.fixture def mock_proxy_logging(): """Mock proxy logging object""" return MockProxyLogging() class TestMCPGuardrailsPreCall: """Test MCP guardrails for pre-call hooks""" @pytest.mark.asyncio async def test_pii_guardrail_blocks_pre_call(self, mock_pii_guardrail, mock_user_api_key, mock_cache): """Test that PII guardrail properly blocks pre-call""" proxy_logging = MockProxyLogging([mock_pii_guardrail]) # Create MCP request request_obj = MCPPreCallRequestObject( tool_name="email_tool", arguments={"email": "test@example.com"}, server_name="email_server", user_api_key_auth=mock_user_api_key.model_dump(), hidden_params=HiddenParams() ) kwargs = { "name": "email_tool", "arguments": {"email": "test@example.com"}, "server_name": "email_server", "user_api_key_auth": mock_user_api_key, } # Test that BlockedPiiEntityError is raised with pytest.raises(BlockedPiiEntityError) as excinfo: await proxy_logging.async_pre_mcp_tool_call_hook( kwargs=kwargs, request_obj=request_obj, start_time=datetime.now(), end_time=datetime.now(), ) # Verify the error details assert excinfo.value.entity_type == "EMAIL_ADDRESS" assert excinfo.value.guardrail_name == "mock-pii-guardrail" assert mock_pii_guardrail.call_count == 1 @pytest.mark.asyncio async def test_pii_guardrail_allows_pre_call(self, mock_pii_guardrail_allow, mock_user_api_key, mock_cache): """Test that PII guardrail allows pre-call when configured to allow""" proxy_logging = MockProxyLogging([mock_pii_guardrail_allow]) request_obj = MCPPreCallRequestObject( tool_name="email_tool", arguments={"email": "test@example.com"}, server_name="email_server", user_api_key_auth=mock_user_api_key.model_dump(), hidden_params=HiddenParams() ) kwargs = { "name": "email_tool", "arguments": {"email": "test@example.com"}, "server_name": "email_server", "user_api_key_auth": mock_user_api_key, } # Test that no exception is raised result = await proxy_logging.async_pre_mcp_tool_call_hook( kwargs=kwargs, request_obj=request_obj, start_time=datetime.now(), end_time=datetime.now(), ) assert result is None assert mock_pii_guardrail_allow.call_count == 1 @pytest.mark.asyncio async def test_content_guardrail_blocks_pre_call(self, mock_content_guardrail, mock_user_api_key, mock_cache): """Test that content guardrail properly blocks pre-call""" proxy_logging = MockProxyLogging([mock_content_guardrail]) request_obj = MCPPreCallRequestObject( tool_name="content_tool", arguments={"content": "sensitive content"}, server_name="content_server", user_api_key_auth=mock_user_api_key.model_dump(), hidden_params=HiddenParams() ) kwargs = { "name": "content_tool", "arguments": {"content": "sensitive content"}, "server_name": "content_server", "user_api_key_auth": mock_user_api_key, } # Test that GuardrailRaisedException is raised with pytest.raises(GuardrailRaisedException) as excinfo: await proxy_logging.async_pre_mcp_tool_call_hook( kwargs=kwargs, request_obj=request_obj, start_time=datetime.now(), end_time=datetime.now(), ) # Verify the error details assert "Content violates policy" in str(excinfo.value) assert excinfo.value.guardrail_name == "mock-content-guardrail" assert mock_content_guardrail.call_count == 1 @pytest.mark.asyncio async def test_http_guardrail_blocks_pre_call(self, mock_http_guardrail, mock_user_api_key, mock_cache): """Test that HTTP guardrail properly blocks pre-call""" proxy_logging = MockProxyLogging([mock_http_guardrail]) request_obj = MCPPreCallRequestObject( tool_name="http_tool", arguments={"url": "http://example.com"}, server_name="http_server", user_api_key_auth=mock_user_api_key.model_dump(), hidden_params=HiddenParams() ) kwargs = { "name": "http_tool", "arguments": {"url": "http://example.com"}, "server_name": "http_server", "user_api_key_auth": mock_user_api_key, } # Test that HTTPException is raised with pytest.raises(HTTPException) as excinfo: await proxy_logging.async_pre_mcp_tool_call_hook( kwargs=kwargs, request_obj=request_obj, start_time=datetime.now(), end_time=datetime.now(), ) # Verify the error details assert excinfo.value.status_code == 400 assert "Violated guardrail policy" in str(excinfo.value.detail) assert mock_http_guardrail.call_count == 1 @pytest.mark.asyncio async def test_multiple_guardrails_pre_call(self, mock_pii_guardrail, mock_content_guardrail, mock_user_api_key, mock_cache): """Test multiple guardrails - first one should block""" proxy_logging = MockProxyLogging([mock_pii_guardrail, mock_content_guardrail]) request_obj = MCPPreCallRequestObject( tool_name="test_tool", arguments={"email": "test@example.com"}, server_name="test_server", user_api_key_auth=mock_user_api_key.model_dump(), hidden_params=HiddenParams() ) kwargs = { "name": "test_tool", "arguments": {"email": "test@example.com"}, "server_name": "test_server", "user_api_key_auth": mock_user_api_key, } # Test that first guardrail blocks with pytest.raises(BlockedPiiEntityError): await proxy_logging.async_pre_mcp_tool_call_hook( kwargs=kwargs, request_obj=request_obj, start_time=datetime.now(), end_time=datetime.now(), ) # Verify only first guardrail was called assert mock_pii_guardrail.call_count == 1 assert mock_content_guardrail.call_count == 0 class TestMCPGuardrailsDuringCall: """Test MCP guardrails for during-call hooks""" @pytest.mark.asyncio async def test_during_call_guardrail_blocks(self, mock_during_guardrail, mock_user_api_key, mock_cache): """Test that during-call guardrail properly blocks execution""" proxy_logging = MockProxyLogging([mock_during_guardrail]) request_obj = MCPDuringCallRequestObject( tool_name="phone_tool", arguments={"phone": "555-123-4567"}, server_name="phone_server", start_time=datetime.now().timestamp(), hidden_params=HiddenParams() ) kwargs = { "name": "phone_tool", "arguments": {"phone": "555-123-4567"}, "server_name": "phone_server", } # Test that BlockedPiiEntityError is raised with pytest.raises(BlockedPiiEntityError) as excinfo: await proxy_logging.async_during_mcp_tool_call_hook( kwargs=kwargs, request_obj=request_obj, start_time=datetime.now(), end_time=datetime.now(), ) # Verify the error details assert excinfo.value.entity_type == "PHONE_NUMBER" assert excinfo.value.guardrail_name == "mock-during-guardrail" assert mock_during_guardrail.call_count == 1 class TestMCPGuardrailsIntegration: """Test MCP guardrails integration with MCP server manager""" @pytest.mark.asyncio async def test_mcp_server_manager_with_guardrails(self): """Test MCP server manager with guardrail integration""" mock_proxy_logging = MockProxyLogging([MockPiiGuardrail(should_block=True)]) # Test that guardrail exception is properly raised in the hook with pytest.raises(BlockedPiiEntityError): await mock_proxy_logging.async_pre_mcp_tool_call_hook( kwargs={"name": "email_tool", "arguments": {"email": "test@example.com"}}, request_obj=MagicMock(), start_time=datetime.now(), end_time=datetime.now(), ) @pytest.mark.asyncio async def test_guardrail_exception_propagation(self): """Test that guardrail exceptions properly propagate through the system""" # Test BlockedPiiEntityError with pytest.raises(BlockedPiiEntityError): raise BlockedPiiEntityError( entity_type="EMAIL_ADDRESS", guardrail_name="test-guardrail" ) # Test GuardrailRaisedException with pytest.raises(GuardrailRaisedException): raise GuardrailRaisedException( guardrail_name="test-guardrail", message="Test message" ) # Test HTTPException with pytest.raises(HTTPException): raise HTTPException( status_code=400, detail={"error": "Test error"} ) class TestMCPGuardrailsErrorHandling: """Test MCP guardrails error handling scenarios""" @pytest.mark.asyncio async def test_non_guardrail_exception_logging(self, mock_user_api_key, mock_cache): """Test that non-guardrail exceptions are logged as non-blocking""" class MockFailingGuardrail(CustomGuardrail): def should_run_guardrail(self, data: dict, event_type: GuardrailEventHooks) -> bool: return True async def async_pre_call_hook( self, user_api_key_dict: UserAPIKeyAuth, cache: DualCache, data: dict, call_type: str, ): raise Exception("Non-guardrail error") proxy_logging = MockProxyLogging([MockFailingGuardrail()]) request_obj = MCPPreCallRequestObject( tool_name="test_tool", arguments={"test": "data"}, server_name="test_server", user_api_key_auth=mock_user_api_key.model_dump(), hidden_params=HiddenParams() ) kwargs = { "name": "test_tool", "arguments": {"test": "data"}, "server_name": "test_server", "user_api_key_auth": mock_user_api_key, } # Test that non-guardrail exceptions are handled gracefully result = await proxy_logging.async_pre_mcp_tool_call_hook( kwargs=kwargs, request_obj=request_obj, start_time=datetime.now(), end_time=datetime.now(), ) # Should return None (not raise exception) assert result is None @pytest.mark.asyncio async def test_guardrail_should_not_run(self, mock_user_api_key, mock_cache): """Test that guardrails don't run when should_run_guardrail returns False""" class MockConditionalGuardrail(CustomGuardrail): def should_run_guardrail(self, data: dict, event_type: GuardrailEventHooks) -> bool: return False # Don't run async def async_pre_call_hook( self, user_api_key_dict: UserAPIKeyAuth, cache: DualCache, data: dict, call_type: str, ): raise BlockedPiiEntityError("EMAIL_ADDRESS", "test-guardrail") proxy_logging = MockProxyLogging([MockConditionalGuardrail()]) request_obj = MCPPreCallRequestObject( tool_name="test_tool", arguments={"test": "data"}, server_name="test_server", user_api_key_auth=mock_user_api_key.model_dump(), hidden_params=HiddenParams() ) kwargs = { "name": "test_tool", "arguments": {"test": "data"}, "server_name": "test_server", "user_api_key_auth": mock_user_api_key, } # Test that guardrail doesn't run and no exception is raised result = await proxy_logging.async_pre_mcp_tool_call_hook( kwargs=kwargs, request_obj=request_obj, start_time=datetime.now(), end_time=datetime.now(), ) # Should return None (guardrail didn't run) assert result is None class TestMCPGuardrailsEdgeCases: """Test MCP guardrails edge cases and error conditions""" @pytest.mark.asyncio async def test_empty_guardrails_list(self, mock_user_api_key, mock_cache): """Test behavior with empty guardrails list""" proxy_logging = MockProxyLogging([]) # No guardrails request_obj = MCPPreCallRequestObject( tool_name="test_tool", arguments={"test": "data"}, server_name="test_server", user_api_key_auth=mock_user_api_key.model_dump(), hidden_params=HiddenParams() ) kwargs = { "name": "test_tool", "arguments": {"test": "data"}, "server_name": "test_server", "user_api_key_auth": mock_user_api_key, } # Should return None without any issues result = await proxy_logging.async_pre_mcp_tool_call_hook( kwargs=kwargs, request_obj=request_obj, start_time=datetime.now(), end_time=datetime.now(), ) assert result is None @pytest.mark.asyncio async def test_guardrail_with_invalid_data(self, mock_user_api_key, mock_cache): """Test guardrail behavior with invalid data""" class MockInvalidDataGuardrail(CustomGuardrail): def should_run_guardrail(self, data: dict, event_type: GuardrailEventHooks) -> bool: return True async def async_pre_call_hook( self, user_api_key_dict: UserAPIKeyAuth, cache: DualCache, data: dict, call_type: str, ): # Try to access invalid data invalid_data = data.get("invalid_key", {}) if invalid_data.get("should_fail"): raise BlockedPiiEntityError("EMAIL_ADDRESS", "test-guardrail") return None proxy_logging = MockProxyLogging([MockInvalidDataGuardrail()]) request_obj = MCPPreCallRequestObject( tool_name="test_tool", arguments={"test": "data"}, server_name="test_server", user_api_key_auth=mock_user_api_key.model_dump(), hidden_params=HiddenParams() ) kwargs = { "name": "test_tool", "arguments": {"test": "data"}, "server_name": "test_server", "user_api_key_auth": mock_user_api_key, } # Should handle invalid data gracefully result = await proxy_logging.async_pre_mcp_tool_call_hook( kwargs=kwargs, request_obj=request_obj, start_time=datetime.now(), end_time=datetime.now(), ) assert result is None if __name__ == "__main__": pytest.main([__file__])