180 lines
7.5 KiB
Python
180 lines
7.5 KiB
Python
import sys
|
|
import os
|
|
import io, asyncio
|
|
import json
|
|
import pytest
|
|
import time
|
|
from litellm import mock_completion
|
|
from unittest.mock import MagicMock, AsyncMock, patch
|
|
sys.path.insert(0, os.path.abspath("../.."))
|
|
import litellm
|
|
from litellm.proxy.guardrails.guardrail_hooks.presidio import _OPTIONAL_PresidioPIIMasking, PresidioPerRequestConfig
|
|
from litellm.integrations.custom_logger import CustomLogger
|
|
from litellm.types.utils import StandardLoggingPayload, StandardLoggingGuardrailInformation
|
|
from litellm.types.guardrails import GuardrailEventHooks
|
|
from typing import Optional
|
|
|
|
|
|
class TestCustomLogger(CustomLogger):
|
|
def __init__(self, *args, **kwargs):
|
|
self.standard_logging_payload: Optional[StandardLoggingPayload] = None
|
|
|
|
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
|
|
self.standard_logging_payload = kwargs.get("standard_logging_object")
|
|
pass
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_standard_logging_payload_includes_guardrail_information():
|
|
"""
|
|
Test that the standard logging payload includes the guardrail information when a guardrail is applied
|
|
"""
|
|
test_custom_logger = TestCustomLogger()
|
|
litellm.callbacks = [test_custom_logger]
|
|
presidio_guard = _OPTIONAL_PresidioPIIMasking(
|
|
guardrail_name="presidio_guard",
|
|
event_hook=GuardrailEventHooks.pre_call,
|
|
presidio_analyzer_api_base=os.getenv("PRESIDIO_ANALYZER_API_BASE"),
|
|
presidio_anonymizer_api_base=os.getenv("PRESIDIO_ANONYMIZER_API_BASE"),
|
|
)
|
|
# 1. call the pre call hook with guardrail
|
|
request_data = {
|
|
"model": "gpt-4o",
|
|
"messages": [
|
|
{"role": "user", "content": "Hello, my phone number is +1 412 555 1212"},
|
|
],
|
|
"mock_response": "Hello",
|
|
"guardrails": ["presidio_guard"],
|
|
"metadata": {},
|
|
}
|
|
await presidio_guard.async_pre_call_hook(
|
|
user_api_key_dict={},
|
|
cache=None,
|
|
data=request_data,
|
|
call_type="acompletion"
|
|
)
|
|
|
|
# 2. call litellm.acompletion
|
|
response = await litellm.acompletion(**request_data)
|
|
|
|
# 3. assert that the standard logging payload includes the guardrail information
|
|
await asyncio.sleep(1)
|
|
print("got standard logging payload=", json.dumps(test_custom_logger.standard_logging_payload, indent=4, default=str))
|
|
assert test_custom_logger.standard_logging_payload is not None
|
|
assert test_custom_logger.standard_logging_payload["guardrail_information"] is not None
|
|
assert test_custom_logger.standard_logging_payload["guardrail_information"]["guardrail_name"] == "presidio_guard"
|
|
assert test_custom_logger.standard_logging_payload["guardrail_information"]["guardrail_mode"] == GuardrailEventHooks.pre_call
|
|
|
|
# assert that the guardrail_response is a response from presidio analyze
|
|
presidio_response = test_custom_logger.standard_logging_payload["guardrail_information"]["guardrail_response"]
|
|
assert isinstance(presidio_response, list)
|
|
for response_item in presidio_response:
|
|
assert "analysis_explanation" in response_item
|
|
assert "start" in response_item
|
|
assert "end" in response_item
|
|
assert "score" in response_item
|
|
assert "entity_type" in response_item
|
|
|
|
# assert that the duration is not None
|
|
assert test_custom_logger.standard_logging_payload["guardrail_information"]["duration"] is not None
|
|
assert test_custom_logger.standard_logging_payload["guardrail_information"]["duration"] > 0
|
|
|
|
# assert that we get the count of masked entities
|
|
assert test_custom_logger.standard_logging_payload["guardrail_information"]["masked_entity_count"] is not None
|
|
assert test_custom_logger.standard_logging_payload["guardrail_information"]["masked_entity_count"]["PHONE_NUMBER"] == 1
|
|
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
@pytest.mark.skip(reason="Local only test")
|
|
async def test_langfuse_trace_includes_guardrail_information():
|
|
"""
|
|
Test that the langfuse trace includes the guardrail information when a guardrail is applied
|
|
"""
|
|
import httpx
|
|
from unittest.mock import AsyncMock, patch
|
|
from litellm.integrations.langfuse.langfuse_prompt_management import LangfusePromptManagement
|
|
callback = LangfusePromptManagement(flush_interval=3)
|
|
import json
|
|
|
|
# Create a mock Response object
|
|
mock_response = AsyncMock(spec=httpx.Response)
|
|
mock_response.status_code = 200
|
|
mock_response.json.return_value = {"status": "success"}
|
|
|
|
# Create mock for httpx.Client.post
|
|
mock_post = AsyncMock()
|
|
mock_post.return_value = mock_response
|
|
|
|
with patch("httpx.Client.post", mock_post):
|
|
litellm._turn_on_debug()
|
|
litellm.callbacks = [callback]
|
|
presidio_guard = _OPTIONAL_PresidioPIIMasking(
|
|
guardrail_name="presidio_guard",
|
|
event_hook=GuardrailEventHooks.pre_call,
|
|
presidio_analyzer_api_base=os.getenv("PRESIDIO_ANALYZER_API_BASE"),
|
|
presidio_anonymizer_api_base=os.getenv("PRESIDIO_ANONYMIZER_API_BASE"),
|
|
)
|
|
# 1. call the pre call hook with guardrail
|
|
request_data = {
|
|
"model": "gpt-4o",
|
|
"messages": [
|
|
{"role": "user", "content": "Hello, my phone number is +1 412 555 1212"},
|
|
],
|
|
"mock_response": "Hello",
|
|
"guardrails": ["presidio_guard"],
|
|
"metadata": {},
|
|
}
|
|
await presidio_guard.async_pre_call_hook(
|
|
user_api_key_dict={},
|
|
cache=None,
|
|
data=request_data,
|
|
call_type="acompletion"
|
|
)
|
|
|
|
# 2. call litellm.acompletion
|
|
response = await litellm.acompletion(**request_data)
|
|
|
|
# 3. Wait for async logging operations to complete
|
|
await asyncio.sleep(5)
|
|
|
|
# 4. Verify the Langfuse payload
|
|
assert mock_post.call_count >= 1
|
|
url = mock_post.call_args[0][0]
|
|
request_body = mock_post.call_args[1].get("content")
|
|
|
|
# Parse the JSON body
|
|
actual_payload = json.loads(request_body)
|
|
print("\nLangfuse payload:", json.dumps(actual_payload, indent=2))
|
|
|
|
# Look for the guardrail span in the payload
|
|
guardrail_span = None
|
|
for item in actual_payload["batch"]:
|
|
if (item["type"] == "span-create" and
|
|
item["body"].get("name") == "guardrail"):
|
|
guardrail_span = item
|
|
break
|
|
|
|
# Assert that the guardrail span exists
|
|
assert guardrail_span is not None, "No guardrail span found in Langfuse payload"
|
|
|
|
# Validate the structure of the guardrail span
|
|
assert guardrail_span["body"]["name"] == "guardrail"
|
|
assert "metadata" in guardrail_span["body"]
|
|
assert guardrail_span["body"]["metadata"]["guardrail_name"] == "presidio_guard"
|
|
assert guardrail_span["body"]["metadata"]["guardrail_mode"] == GuardrailEventHooks.pre_call
|
|
assert "guardrail_masked_entity_count" in guardrail_span["body"]["metadata"]
|
|
assert guardrail_span["body"]["metadata"]["guardrail_masked_entity_count"]["PHONE_NUMBER"] == 1
|
|
|
|
# Validate the output format matches the expected structure
|
|
assert "output" in guardrail_span["body"]
|
|
assert isinstance(guardrail_span["body"]["output"], list)
|
|
assert len(guardrail_span["body"]["output"]) > 0
|
|
|
|
# Validate the first output item has the expected structure
|
|
output_item = guardrail_span["body"]["output"][0]
|
|
assert "entity_type" in output_item
|
|
assert output_item["entity_type"] == "PHONE_NUMBER"
|
|
assert "score" in output_item
|
|
assert "start" in output_item
|
|
assert "end" in output_item |