85 lines
3.1 KiB
Python
85 lines
3.1 KiB
Python
from unittest.mock import AsyncMock, MagicMock
|
|
|
|
import pytest
|
|
|
|
from litellm.integrations.custom_guardrail import CustomGuardrail
|
|
from litellm.proxy._types import CallTypes, UserAPIKeyAuth
|
|
|
|
|
|
class TestCustomGuardrailDeploymentHook:
    """Tests for ``CustomGuardrail.async_pre_call_deployment_hook``."""

    @pytest.mark.asyncio
    async def test_async_pre_call_deployment_hook_no_guardrails(self):
        """Test that method returns kwargs unchanged when no guardrails are present"""
        custom_guardrail = CustomGuardrail()

        def build_kwargs(guardrails):
            # Build a brand-new dict (and nested message list) on every call so
            # the expected value never aliases the object handed to the hook.
            return {
                "messages": [{"role": "user", "content": "test message"}],
                "model": "gpt-3.5-turbo",
                "guardrails": guardrails,
            }

        # Cover both "guardrails is None" and "guardrails is not a list".
        for guardrails in (None, "not_a_list"):
            kwargs = build_kwargs(guardrails)

            result = await custom_guardrail.async_pre_call_deployment_hook(
                kwargs=kwargs, call_type=CallTypes.completion
            )

            # Compare against an independently built dict: comparing ``result``
            # with ``kwargs`` itself would pass trivially if the hook returned
            # the same (possibly mutated) object.
            assert result == build_kwargs(guardrails)

    @pytest.mark.asyncio
    async def test_async_pre_call_deployment_hook_with_guardrails_and_message_update(
        self,
    ):
        """Test that method processes guardrails and updates messages when result contains messages"""
        custom_guardrail = CustomGuardrail()

        # Mock the async_pre_call_hook method so it returns filtered messages
        mock_result = {"messages": [{"role": "user", "content": "filtered message"}]}
        custom_guardrail.async_pre_call_hook = AsyncMock(return_value=mock_result)

        original_messages = [{"role": "user", "content": "original message"}]
        kwargs = {
            "messages": original_messages,
            "model": "gpt-3.5-turbo",
            "guardrails": ["some_guardrail"],
            "user_api_key_user_id": "test_user",
            "user_api_key_team_id": "test_team",
            "user_api_key_end_user_id": "test_end_user",
            "user_api_key_hash": "test_hash",
            "user_api_key_request_route": "test_route",
        }

        result = await custom_guardrail.async_pre_call_deployment_hook(
            kwargs=kwargs, call_type=CallTypes.completion
        )

        # Verify async_pre_call_hook was awaited exactly once (a bare "called"
        # check would also pass for a coroutine that was never awaited).
        custom_guardrail.async_pre_call_hook.assert_awaited_once()
        hook_kwargs = custom_guardrail.async_pre_call_hook.call_args.kwargs

        # Check that UserAPIKeyAuth was created from the user_api_key_* entries
        user_api_key_dict = hook_kwargs["user_api_key_dict"]
        assert isinstance(user_api_key_dict, UserAPIKeyAuth)
        assert user_api_key_dict.user_id == "test_user"
        assert user_api_key_dict.team_id == "test_team"
        assert user_api_key_dict.end_user_id == "test_end_user"
        assert user_api_key_dict.api_key == "test_hash"
        assert user_api_key_dict.request_route == "test_route"

        # Check other parameters
        assert hook_kwargs["data"] == kwargs
        assert hook_kwargs["call_type"] == "completion"

        # Verify messages were updated in result
        assert result["messages"] == mock_result["messages"]
        assert result["messages"] != original_messages
|