Added LiteLLM to the stack
This commit is contained in:
@@ -0,0 +1,84 @@
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from litellm.integrations.custom_guardrail import CustomGuardrail
|
||||
from litellm.proxy._types import CallTypes, UserAPIKeyAuth
|
||||
|
||||
|
||||
class TestCustomGuardrailDeploymentHook:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_pre_call_deployment_hook_no_guardrails(self):
|
||||
"""Test that method returns kwargs unchanged when no guardrails are present"""
|
||||
custom_guardrail = CustomGuardrail()
|
||||
|
||||
# Test with guardrails as None
|
||||
kwargs = {
|
||||
"messages": [{"role": "user", "content": "test message"}],
|
||||
"model": "gpt-3.5-turbo",
|
||||
"guardrails": None,
|
||||
}
|
||||
|
||||
result = await custom_guardrail.async_pre_call_deployment_hook(
|
||||
kwargs=kwargs, call_type=CallTypes.completion
|
||||
)
|
||||
|
||||
assert result == kwargs
|
||||
|
||||
# Test with guardrails as non-list
|
||||
kwargs["guardrails"] = "not_a_list"
|
||||
|
||||
result = await custom_guardrail.async_pre_call_deployment_hook(
|
||||
kwargs=kwargs, call_type=CallTypes.completion
|
||||
)
|
||||
|
||||
assert result == kwargs
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_pre_call_deployment_hook_with_guardrails_and_message_update(
|
||||
self,
|
||||
):
|
||||
"""Test that method processes guardrails and updates messages when result contains messages"""
|
||||
custom_guardrail = CustomGuardrail()
|
||||
|
||||
# Mock the async_pre_call_hook method
|
||||
mock_result = {"messages": [{"role": "user", "content": "filtered message"}]}
|
||||
custom_guardrail.async_pre_call_hook = AsyncMock(return_value=mock_result)
|
||||
|
||||
original_messages = [{"role": "user", "content": "original message"}]
|
||||
kwargs = {
|
||||
"messages": original_messages,
|
||||
"model": "gpt-3.5-turbo",
|
||||
"guardrails": ["some_guardrail"],
|
||||
"user_api_key_user_id": "test_user",
|
||||
"user_api_key_team_id": "test_team",
|
||||
"user_api_key_end_user_id": "test_end_user",
|
||||
"user_api_key_hash": "test_hash",
|
||||
"user_api_key_request_route": "test_route",
|
||||
}
|
||||
|
||||
result = await custom_guardrail.async_pre_call_deployment_hook(
|
||||
kwargs=kwargs, call_type=CallTypes.completion
|
||||
)
|
||||
|
||||
# Verify async_pre_call_hook was called with correct parameters
|
||||
custom_guardrail.async_pre_call_hook.assert_called_once()
|
||||
call_args = custom_guardrail.async_pre_call_hook.call_args
|
||||
|
||||
# Check that UserAPIKeyAuth was created properly
|
||||
user_api_key_dict = call_args[1]["user_api_key_dict"]
|
||||
assert isinstance(user_api_key_dict, UserAPIKeyAuth)
|
||||
assert user_api_key_dict.user_id == "test_user"
|
||||
assert user_api_key_dict.team_id == "test_team"
|
||||
assert user_api_key_dict.end_user_id == "test_end_user"
|
||||
assert user_api_key_dict.api_key == "test_hash"
|
||||
assert user_api_key_dict.request_route == "test_route"
|
||||
|
||||
# Check other parameters
|
||||
assert call_args[1]["data"] == kwargs
|
||||
assert call_args[1]["call_type"] == "completion"
|
||||
|
||||
# Verify messages were updated in result
|
||||
assert result["messages"] == mock_result["messages"]
|
||||
assert result["messages"] != original_messages
|
Reference in New Issue
Block a user