Added LiteLLM to the stack
This commit is contained in:
@@ -0,0 +1,180 @@
|
||||
import sys
|
||||
import os
|
||||
import io, asyncio
|
||||
import json
|
||||
import pytest
|
||||
import time
|
||||
from litellm import mock_completion
|
||||
from unittest.mock import MagicMock, AsyncMock, patch
|
||||
sys.path.insert(0, os.path.abspath("../.."))
|
||||
import litellm
|
||||
from litellm.proxy.guardrails.guardrail_hooks.presidio import _OPTIONAL_PresidioPIIMasking, PresidioPerRequestConfig
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm.types.utils import StandardLoggingPayload, StandardLoggingGuardrailInformation
|
||||
from litellm.types.guardrails import GuardrailEventHooks
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class TestCustomLogger(CustomLogger):
    """Callback that remembers the standard logging payload of the last success event.

    Tests register an instance in ``litellm.callbacks`` and then inspect
    ``standard_logging_payload`` after a completion call has finished.
    """

    # This is a helper, not a test case. Without the opt-out, pytest sees the
    # ``Test`` prefix, tries to collect the class, and warns about ``__init__``.
    __test__ = False

    def __init__(self, *args, **kwargs):
        # Run the base constructor so any state it sets up exists on this
        # instance; the original skipped it entirely.
        super().__init__(*args, **kwargs)
        # Stays None until async_log_success_event has been invoked.
        self.standard_logging_payload: Optional[StandardLoggingPayload] = None

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        """Capture ``kwargs["standard_logging_object"]`` (None when the key is absent)."""
        self.standard_logging_payload = kwargs.get("standard_logging_object")
|
||||
|
||||
@pytest.mark.asyncio
async def test_standard_logging_payload_includes_guardrail_information():
    """
    A guardrail run in the pre-call hook must be reported in the standard logging payload.

    Reads the Presidio endpoints from PRESIDIO_ANALYZER_API_BASE and
    PRESIDIO_ANONYMIZER_API_BASE.
    """
    payload_recorder = TestCustomLogger()
    litellm.callbacks = [payload_recorder]

    pii_guardrail = _OPTIONAL_PresidioPIIMasking(
        guardrail_name="presidio_guard",
        event_hook=GuardrailEventHooks.pre_call,
        presidio_analyzer_api_base=os.getenv("PRESIDIO_ANALYZER_API_BASE"),
        presidio_anonymizer_api_base=os.getenv("PRESIDIO_ANONYMIZER_API_BASE"),
    )

    # Step 1: run the guardrail's pre-call hook over the request.
    user_message = {"role": "user", "content": "Hello, my phone number is +1 412 555 1212"}
    request_data = {
        "model": "gpt-4o",
        "messages": [user_message],
        "mock_response": "Hello",
        "guardrails": ["presidio_guard"],
        "metadata": {},
    }
    await pii_guardrail.async_pre_call_hook(
        user_api_key_dict={},
        cache=None,
        data=request_data,
        call_type="acompletion",
    )

    # Step 2: issue the completion using the same request dict.
    await litellm.acompletion(**request_data)

    # Step 3: give the success callback a moment to run, then inspect what it captured.
    await asyncio.sleep(1)
    logged_payload = payload_recorder.standard_logging_payload
    print("got standard logging payload=", json.dumps(logged_payload, indent=4, default=str))
    assert logged_payload is not None

    guardrail_info = logged_payload["guardrail_information"]
    assert guardrail_info is not None
    assert guardrail_info["guardrail_name"] == "presidio_guard"
    assert guardrail_info["guardrail_mode"] == GuardrailEventHooks.pre_call

    # The guardrail response should be a list of entries carrying these fields.
    analyzer_results = guardrail_info["guardrail_response"]
    assert isinstance(analyzer_results, list)
    expected_fields = ("analysis_explanation", "start", "end", "score", "entity_type")
    for analyzer_result in analyzer_results:
        for field_name in expected_fields:
            assert field_name in analyzer_result

    # A duration must have been recorded and be positive.
    recorded_duration = guardrail_info["duration"]
    assert recorded_duration is not None
    assert recorded_duration > 0

    # Exactly one phone number should be counted as masked.
    masked_counts = guardrail_info["masked_entity_count"]
    assert masked_counts is not None
    assert masked_counts["PHONE_NUMBER"] == 1
|
||||
|
||||
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.skip(reason="Local only test")
async def test_langfuse_trace_includes_guardrail_information():
    """
    A guardrail run in the pre-call hook must appear as a "guardrail" span in the Langfuse batch.

    The HTTP POST is intercepted with a mock, and the body of the most recent
    call is parsed and checked.
    """
    import httpx
    from litellm.integrations.langfuse.langfuse_prompt_management import LangfusePromptManagement

    callback = LangfusePromptManagement(flush_interval=3)

    # Canned HTTP response handed back by the patched post().
    fake_http_response = AsyncMock(spec=httpx.Response)
    fake_http_response.status_code = 200
    fake_http_response.json.return_value = {"status": "success"}

    # NOTE(review): httpx.Client.post is the synchronous client method, yet it
    # is replaced with an AsyncMock here - confirm this is intended.
    mock_post = AsyncMock(return_value=fake_http_response)

    with patch("httpx.Client.post", mock_post):
        litellm._turn_on_debug()
        litellm.callbacks = [callback]

        pii_guardrail = _OPTIONAL_PresidioPIIMasking(
            guardrail_name="presidio_guard",
            event_hook=GuardrailEventHooks.pre_call,
            presidio_analyzer_api_base=os.getenv("PRESIDIO_ANALYZER_API_BASE"),
            presidio_anonymizer_api_base=os.getenv("PRESIDIO_ANONYMIZER_API_BASE"),
        )

        # Step 1: run the guardrail's pre-call hook over the request.
        user_message = {"role": "user", "content": "Hello, my phone number is +1 412 555 1212"}
        request_data = {
            "model": "gpt-4o",
            "messages": [user_message],
            "mock_response": "Hello",
            "guardrails": ["presidio_guard"],
            "metadata": {},
        }
        await pii_guardrail.async_pre_call_hook(
            user_api_key_dict={},
            cache=None,
            data=request_data,
            call_type="acompletion",
        )

        # Step 2: issue the completion using the same request dict.
        await litellm.acompletion(**request_data)

        # Step 3: leave time for the async logging work to finish.
        await asyncio.sleep(5)

        # Step 4: inspect what was handed to the patched post().
        assert mock_post.call_count >= 1
        post_args, post_kwargs = mock_post.call_args
        posted_url = post_args[0]  # unused, but still requires a positional argument
        raw_body = post_kwargs.get("content")

        actual_payload = json.loads(raw_body)
        print("\nLangfuse payload:", json.dumps(actual_payload, indent=2))

        # First "span-create" event named "guardrail", or None when absent.
        guardrail_span = next(
            (
                event
                for event in actual_payload["batch"]
                if event["type"] == "span-create"
                and event["body"].get("name") == "guardrail"
            ),
            None,
        )
        assert guardrail_span is not None, "No guardrail span found in Langfuse payload"

        # Span name and metadata.
        span_body = guardrail_span["body"]
        assert span_body["name"] == "guardrail"
        assert "metadata" in span_body
        span_metadata = span_body["metadata"]
        assert span_metadata["guardrail_name"] == "presidio_guard"
        assert span_metadata["guardrail_mode"] == GuardrailEventHooks.pre_call
        assert "guardrail_masked_entity_count" in span_metadata
        assert span_metadata["guardrail_masked_entity_count"]["PHONE_NUMBER"] == 1

        # Span output must be a non-empty list.
        assert "output" in span_body
        span_output = span_body["output"]
        assert isinstance(span_output, list)
        assert len(span_output) > 0

        # The first output entry describes the detected phone number.
        first_entity = span_output[0]
        assert "entity_type" in first_entity
        assert first_entity["entity_type"] == "PHONE_NUMBER"
        for field_name in ("score", "start", "end"):
            assert field_name in first_entity
|
Reference in New Issue
Block a user