Added LiteLLM to the stack
This commit is contained in:
@@ -0,0 +1,492 @@
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import unittest.mock as mock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
sys.path.insert(0, os.path.abspath("../../.."))
|
||||
from litellm_enterprise.enterprise_callbacks.send_emails.base_email import (
|
||||
BaseEmailLogger,
|
||||
)
|
||||
from litellm_enterprise.types.enterprise_callbacks.send_emails import (
|
||||
EmailEvent,
|
||||
SendKeyCreatedEmailEvent,
|
||||
)
|
||||
|
||||
from litellm.integrations.email_templates.email_footer import EMAIL_FOOTER
|
||||
from litellm.proxy._types import Litellm_EntityType, WebhookEvent
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def base_email_logger():
|
||||
return BaseEmailLogger()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_send_email():
|
||||
with mock.patch.object(BaseEmailLogger, "send_email") as mock_send:
|
||||
yield mock_send
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_lookup_user_email():
|
||||
with mock.patch.object(
|
||||
BaseEmailLogger, "_lookup_user_email_from_db"
|
||||
) as mock_lookup:
|
||||
yield mock_lookup
|
||||
|
||||
|
||||
def test_format_key_budget(base_email_logger):
|
||||
# Test with budget
|
||||
assert base_email_logger._format_key_budget(100.0) == "$100.0"
|
||||
|
||||
# Test with no budget
|
||||
assert base_email_logger._format_key_budget(None) == "No budget"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_key_created_email(
|
||||
base_email_logger, mock_send_email, mock_lookup_user_email
|
||||
):
|
||||
# Setup test data
|
||||
event = SendKeyCreatedEmailEvent(
|
||||
user_id="test_user",
|
||||
user_email="test@example.com",
|
||||
virtual_key="test_key",
|
||||
max_budget=100.0,
|
||||
spend=0.0,
|
||||
event_group=Litellm_EntityType.USER,
|
||||
event="key_created",
|
||||
event_message="Test Key Created",
|
||||
)
|
||||
|
||||
# Mock environment variables
|
||||
with mock.patch.dict(
|
||||
os.environ,
|
||||
{
|
||||
"EMAIL_LOGO_URL": "https://litellm-listing.s3.amazonaws.com/litellm_logo.png",
|
||||
"EMAIL_SUPPORT_CONTACT": "support@berri.ai",
|
||||
"PROXY_BASE_URL": "http://test.com",
|
||||
},
|
||||
):
|
||||
# Execute
|
||||
await base_email_logger.send_key_created_email(event)
|
||||
|
||||
# Verify
|
||||
mock_send_email.assert_called_once()
|
||||
call_args = mock_send_email.call_args[1]
|
||||
assert call_args["from_email"] == BaseEmailLogger.DEFAULT_LITELLM_EMAIL
|
||||
assert call_args["to_email"] == ["test@example.com"]
|
||||
assert call_args["subject"] == "LiteLLM: Test Key Created"
|
||||
assert "test_key" in call_args["html_body"]
|
||||
assert "$100.0" in call_args["html_body"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_user_invitation_email(
|
||||
base_email_logger, mock_send_email, mock_lookup_user_email
|
||||
):
|
||||
# Setup test data
|
||||
event = WebhookEvent(
|
||||
user_id="test_user",
|
||||
user_email="invited@example.com",
|
||||
event_group=Litellm_EntityType.USER,
|
||||
event="internal_user_created",
|
||||
event_message="User Invitation",
|
||||
spend=0.0,
|
||||
)
|
||||
|
||||
# Mock environment variables
|
||||
with mock.patch.dict(
|
||||
os.environ,
|
||||
{
|
||||
"EMAIL_LOGO_URL": "https://litellm-listing.s3.amazonaws.com/litellm_logo.png",
|
||||
"EMAIL_SUPPORT_CONTACT": "support@berri.ai",
|
||||
"PROXY_BASE_URL": "http://test.com",
|
||||
},
|
||||
):
|
||||
# Execute
|
||||
await base_email_logger.send_user_invitation_email(event)
|
||||
|
||||
# Verify
|
||||
mock_send_email.assert_called_once()
|
||||
call_args = mock_send_email.call_args[1]
|
||||
assert call_args["from_email"] == BaseEmailLogger.DEFAULT_LITELLM_EMAIL
|
||||
assert call_args["to_email"] == ["invited@example.com"]
|
||||
assert call_args["subject"] == "LiteLLM: User Invitation"
|
||||
assert "invited@example.com" in call_args["html_body"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_user_invitation_email_from_db(
|
||||
base_email_logger, mock_send_email, mock_lookup_user_email
|
||||
):
|
||||
# Setup test data with no direct email but one in the database
|
||||
event = WebhookEvent(
|
||||
user_id="test_user",
|
||||
event_group=Litellm_EntityType.USER,
|
||||
event="internal_user_created",
|
||||
event_message="User Invitation",
|
||||
spend=0.0,
|
||||
)
|
||||
|
||||
# Mock the lookup to return an email
|
||||
mock_lookup_user_email.return_value = "db_user@example.com"
|
||||
|
||||
# Mock environment variables
|
||||
with mock.patch.dict(
|
||||
os.environ,
|
||||
{
|
||||
"EMAIL_LOGO_URL": "https://litellm-listing.s3.amazonaws.com/litellm_logo.png",
|
||||
"EMAIL_SUPPORT_CONTACT": "support@berri.ai",
|
||||
"PROXY_BASE_URL": "http://test.com",
|
||||
},
|
||||
):
|
||||
# Execute
|
||||
await base_email_logger.send_user_invitation_email(event)
|
||||
|
||||
# Verify
|
||||
mock_lookup_user_email.assert_called_once_with(user_id="test_user")
|
||||
mock_send_email.assert_called_once()
|
||||
call_args = mock_send_email.call_args[1]
|
||||
assert call_args["from_email"] == BaseEmailLogger.DEFAULT_LITELLM_EMAIL
|
||||
assert call_args["to_email"] == ["db_user@example.com"]
|
||||
assert call_args["subject"] == "LiteLLM: User Invitation"
|
||||
assert "db_user@example.com" in call_args["html_body"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_user_invitation_email_no_email(
|
||||
base_email_logger, mock_lookup_user_email
|
||||
):
|
||||
# Setup test data with no email
|
||||
event = WebhookEvent(
|
||||
user_id="test_user",
|
||||
event_group=Litellm_EntityType.USER,
|
||||
event="internal_user_created",
|
||||
event_message="User Invitation",
|
||||
spend=0.0,
|
||||
)
|
||||
|
||||
# Mock lookup to return None
|
||||
mock_lookup_user_email.return_value = None
|
||||
|
||||
# Test that it raises ValueError
|
||||
with pytest.raises(ValueError, match="User email not found"):
|
||||
await base_email_logger.send_user_invitation_email(event)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_key_created_email_no_email(
|
||||
base_email_logger, mock_lookup_user_email
|
||||
):
|
||||
# Setup test data with no email
|
||||
event = SendKeyCreatedEmailEvent(
|
||||
user_id="test_user",
|
||||
user_email=None,
|
||||
virtual_key="test_key",
|
||||
max_budget=100.0,
|
||||
event_message="Test Key Created",
|
||||
event_group=Litellm_EntityType.USER,
|
||||
event="key_created",
|
||||
spend=0.0,
|
||||
)
|
||||
|
||||
# Mock lookup to return None
|
||||
mock_lookup_user_email.return_value = None
|
||||
|
||||
# Test that it raises ValueError
|
||||
with pytest.raises(ValueError, match="User email not found"):
|
||||
await base_email_logger.send_key_created_email(event)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_invitation_link(base_email_logger):
|
||||
# Mock prisma client and its response
|
||||
mock_invitation_row = mock.MagicMock()
|
||||
mock_invitation_row.id = "test-invitation-id"
|
||||
|
||||
mock_prisma = mock.MagicMock()
|
||||
|
||||
# Create an async mock for find_many
|
||||
async def mock_find_many(*args, **kwargs):
|
||||
return [mock_invitation_row]
|
||||
|
||||
mock_prisma.db.litellm_invitationlink.find_many = mock_find_many
|
||||
|
||||
with mock.patch("litellm.proxy.proxy_server.prisma_client", mock_prisma):
|
||||
# Test with valid user_id
|
||||
result = await base_email_logger._get_invitation_link(
|
||||
user_id="test-user", base_url="http://test.com"
|
||||
)
|
||||
assert result == "http://test.com/ui?invitation_id=test-invitation-id"
|
||||
|
||||
# Test with None user_id
|
||||
result = await base_email_logger._get_invitation_link(
|
||||
user_id=None, base_url="http://test.com"
|
||||
)
|
||||
assert result == "http://test.com"
|
||||
|
||||
# Test with no invitation links
|
||||
async def mock_find_many_empty(*args, **kwargs):
|
||||
return []
|
||||
|
||||
mock_prisma.db.litellm_invitationlink.find_many = mock_find_many_empty
|
||||
result = await base_email_logger._get_invitation_link(
|
||||
user_id="test-user", base_url="http://test.com"
|
||||
)
|
||||
assert result == "http://test.com"
|
||||
|
||||
|
||||
def test_construct_invitation_link(base_email_logger):
|
||||
# Test invitation link construction
|
||||
result = base_email_logger._construct_invitation_link(
|
||||
invitation_id="test-id-123", base_url="http://test.com"
|
||||
)
|
||||
assert result == "http://test.com/ui?invitation_id=test-id-123"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_invitation_link_creates_new_when_none_exist(base_email_logger):
|
||||
"""Test that _get_invitation_link creates a new invitation when none exist"""
|
||||
# Mock prisma client with no existing invitation rows
|
||||
mock_prisma = mock.MagicMock()
|
||||
|
||||
# Mock find_many to return empty list (no existing invitations)
|
||||
async def mock_find_many_empty(*args, **kwargs):
|
||||
return []
|
||||
|
||||
mock_prisma.db.litellm_invitationlink.find_many = mock_find_many_empty
|
||||
|
||||
# Mock the create_invitation_for_user function
|
||||
mock_created_invitation = mock.MagicMock()
|
||||
mock_created_invitation.id = "new-invitation-id"
|
||||
|
||||
with mock.patch("litellm.proxy.proxy_server.prisma_client", mock_prisma):
|
||||
with mock.patch(
|
||||
"litellm.proxy.management_helpers.user_invitation.create_invitation_for_user",
|
||||
return_value=mock_created_invitation
|
||||
) as mock_create_invitation:
|
||||
# Execute
|
||||
result = await base_email_logger._get_invitation_link(
|
||||
user_id="test-user", base_url="http://test.com"
|
||||
)
|
||||
|
||||
# Verify that create_invitation_for_user was called
|
||||
mock_create_invitation.assert_called_once()
|
||||
call_args = mock_create_invitation.call_args[1]
|
||||
assert call_args["data"].user_id == "test-user"
|
||||
assert call_args["user_api_key_dict"].user_id == "test-user"
|
||||
|
||||
# Verify the returned link uses the new invitation ID
|
||||
assert result == "http://test.com/ui?invitation_id=new-invitation-id"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_invitation_link_uses_existing_when_available(base_email_logger):
|
||||
"""Test that _get_invitation_link uses existing invitation when available"""
|
||||
# Mock prisma client with existing invitation row
|
||||
mock_invitation_row = mock.MagicMock()
|
||||
mock_invitation_row.id = "existing-invitation-id"
|
||||
|
||||
mock_prisma = mock.MagicMock()
|
||||
|
||||
# Mock find_many to return existing invitation
|
||||
async def mock_find_many_existing(*args, **kwargs):
|
||||
return [mock_invitation_row]
|
||||
|
||||
mock_prisma.db.litellm_invitationlink.find_many = mock_find_many_existing
|
||||
|
||||
with mock.patch("litellm.proxy.proxy_server.prisma_client", mock_prisma):
|
||||
with mock.patch(
|
||||
"litellm.proxy.management_helpers.user_invitation.create_invitation_for_user"
|
||||
) as mock_create_invitation:
|
||||
# Execute
|
||||
result = await base_email_logger._get_invitation_link(
|
||||
user_id="test-user", base_url="http://test.com"
|
||||
)
|
||||
|
||||
# Verify that create_invitation_for_user was NOT called
|
||||
mock_create_invitation.assert_not_called()
|
||||
|
||||
# Verify the returned link uses the existing invitation ID
|
||||
assert result == "http://test.com/ui?invitation_id=existing-invitation-id"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_invitation_link_creates_new_when_list_is_none(base_email_logger):
|
||||
"""Test that _get_invitation_link creates a new invitation when invitation_rows is None"""
|
||||
# Mock prisma client to return None
|
||||
mock_prisma = mock.MagicMock()
|
||||
|
||||
# Mock find_many to return None
|
||||
async def mock_find_many_none(*args, **kwargs):
|
||||
return None
|
||||
|
||||
mock_prisma.db.litellm_invitationlink.find_many = mock_find_many_none
|
||||
|
||||
# Mock the create_invitation_for_user function
|
||||
mock_created_invitation = mock.MagicMock()
|
||||
mock_created_invitation.id = "new-invitation-from-none"
|
||||
|
||||
with mock.patch("litellm.proxy.proxy_server.prisma_client", mock_prisma):
|
||||
with mock.patch(
|
||||
"litellm.proxy.management_helpers.user_invitation.create_invitation_for_user",
|
||||
return_value=mock_created_invitation
|
||||
) as mock_create_invitation:
|
||||
# Execute
|
||||
result = await base_email_logger._get_invitation_link(
|
||||
user_id="test-user", base_url="http://test.com"
|
||||
)
|
||||
|
||||
# Verify that create_invitation_for_user was called
|
||||
mock_create_invitation.assert_called_once()
|
||||
call_args = mock_create_invitation.call_args[1]
|
||||
assert call_args["data"].user_id == "test-user"
|
||||
assert call_args["user_api_key_dict"].user_id == "test-user"
|
||||
|
||||
# Verify the returned link uses the new invitation ID
|
||||
assert result == "http://test.com/ui?invitation_id=new-invitation-from-none"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_email_params_user_invitation(
|
||||
base_email_logger, mock_lookup_user_email
|
||||
):
|
||||
# Mock environment variables
|
||||
with mock.patch.dict(
|
||||
os.environ,
|
||||
{
|
||||
"EMAIL_LOGO_URL": "https://litellm-listing.s3.amazonaws.com/litellm_logo.png",
|
||||
"EMAIL_SUPPORT_CONTACT": "support@berri.ai",
|
||||
"PROXY_BASE_URL": "http://test.com",
|
||||
},
|
||||
):
|
||||
# Mock invitation link
|
||||
with mock.patch.object(
|
||||
base_email_logger,
|
||||
"_get_invitation_link",
|
||||
return_value="http://test.com/ui?invitation_id=test-id",
|
||||
):
|
||||
# Test with user invitation event
|
||||
result = await base_email_logger._get_email_params(
|
||||
email_event=EmailEvent.new_user_invitation,
|
||||
user_id="test-user",
|
||||
user_email="test@example.com",
|
||||
)
|
||||
|
||||
assert result.logo_url == "https://litellm-listing.s3.amazonaws.com/litellm_logo.png"
|
||||
assert result.support_contact == "support@berri.ai"
|
||||
assert result.base_url == "http://test.com/ui?invitation_id=test-id"
|
||||
assert result.recipient_email == "test@example.com"
|
||||
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_env_vars(monkeypatch):
|
||||
"""Set up test environment variables"""
|
||||
monkeypatch.setenv("EMAIL_LOGO_URL", "https://test-company.com/logo.png")
|
||||
monkeypatch.setenv("EMAIL_SUPPORT_CONTACT", "support@test-company.com")
|
||||
monkeypatch.setenv("EMAIL_SIGNATURE", "Best regards,\nTest Company Team")
|
||||
monkeypatch.setenv("EMAIL_SUBJECT_INVITATION", "Welcome to Test Company!")
|
||||
monkeypatch.setenv("EMAIL_SUBJECT_KEY_CREATED", "Your Test Company API Key")
|
||||
monkeypatch.setenv("PROXY_BASE_URL", "http://test.com")
|
||||
monkeypatch.setenv("PROXY_API_URL", "https://test.com")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_email_params_custom_templates_premium_user(mock_env_vars):
|
||||
"""Test that _get_email_params returns correct values with custom templates for premium users"""
|
||||
# Mock premium_user as True
|
||||
with patch("litellm.proxy.proxy_server.premium_user", True):
|
||||
email_logger = BaseEmailLogger()
|
||||
|
||||
# Test invitation email params
|
||||
invitation_params = await email_logger._get_email_params(
|
||||
email_event=EmailEvent.new_user_invitation,
|
||||
user_id="testid",
|
||||
user_email="test@example.com",
|
||||
event_message="New User Invitation"
|
||||
)
|
||||
|
||||
assert invitation_params.subject == "Welcome to Test Company!"
|
||||
assert invitation_params.signature == "Best regards,\nTest Company Team"
|
||||
assert invitation_params.logo_url == "https://test-company.com/logo.png"
|
||||
assert invitation_params.support_contact == "support@test-company.com"
|
||||
assert invitation_params.base_url == "http://test.com"
|
||||
|
||||
# Test key created email params
|
||||
key_params = await email_logger._get_email_params(
|
||||
email_event=EmailEvent.virtual_key_created,
|
||||
user_id="testid",
|
||||
user_email="test@example.com",
|
||||
event_message="API Key Created"
|
||||
)
|
||||
|
||||
assert key_params.subject == "Your Test Company API Key"
|
||||
assert key_params.signature == "Best regards,\nTest Company Team"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_email_params_non_premium_user(mock_env_vars):
|
||||
"""Test that non-premium users get default templates even when custom ones are provided"""
|
||||
# Mock premium_user as False
|
||||
with patch("litellm.proxy.proxy_server.premium_user", False):
|
||||
email_logger = BaseEmailLogger()
|
||||
|
||||
# Test invitation email params
|
||||
email_params = await email_logger._get_email_params(
|
||||
email_event=EmailEvent.new_user_invitation,
|
||||
user_email="test@example.com",
|
||||
event_message="New User Invitation"
|
||||
)
|
||||
|
||||
# Should use default values even though custom values are set in env
|
||||
assert email_params.subject == "LiteLLM: New User Invitation"
|
||||
assert email_params.signature == EMAIL_FOOTER
|
||||
assert email_params.logo_url == "https://litellm-listing.s3.amazonaws.com/litellm_logo.png"
|
||||
assert email_params.support_contact == "support@berri.ai"
|
||||
|
||||
|
||||
# Test key created email params
|
||||
key_params = await email_logger._get_email_params(
|
||||
email_event=EmailEvent.virtual_key_created,
|
||||
user_email="test@example.com",
|
||||
event_message="API Key Created"
|
||||
)
|
||||
|
||||
assert key_params.subject == "LiteLLM: API Key Created"
|
||||
assert key_params.signature == EMAIL_FOOTER
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_email_params_default_templates(monkeypatch):
|
||||
"""Test that _get_email_params uses default templates when custom ones aren't provided"""
|
||||
# Clear any existing environment variables
|
||||
monkeypatch.delenv("EMAIL_SUBJECT_INVITATION", raising=False)
|
||||
monkeypatch.delenv("EMAIL_SUBJECT_KEY_CREATED", raising=False)
|
||||
monkeypatch.delenv("EMAIL_SIGNATURE", raising=False)
|
||||
|
||||
# Mock premium_user as True (shouldn't matter since no custom values are set)
|
||||
with patch("litellm.proxy.proxy_server.premium_user", True):
|
||||
email_logger = BaseEmailLogger()
|
||||
|
||||
# Test invitation email params with default template
|
||||
invitation_params = await email_logger._get_email_params(
|
||||
email_event=EmailEvent.new_user_invitation,
|
||||
user_email="test@example.com",
|
||||
event_message="New User Invitation"
|
||||
)
|
||||
|
||||
assert invitation_params.subject == "LiteLLM: New User Invitation"
|
||||
assert invitation_params.signature == EMAIL_FOOTER
|
||||
|
||||
# Test key created email params with default template
|
||||
key_params = await email_logger._get_email_params(
|
||||
email_event=EmailEvent.virtual_key_created,
|
||||
user_email="test@example.com",
|
||||
event_message="API Key Created"
|
||||
)
|
||||
|
||||
assert key_params.subject == "LiteLLM: API Key Created"
|
||||
assert key_params.signature == EMAIL_FOOTER
|
@@ -0,0 +1,265 @@
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import unittest.mock as mock
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
sys.path.insert(0, os.path.abspath("../../.."))
|
||||
|
||||
from litellm_enterprise.enterprise_callbacks.send_emails.endpoints import (
|
||||
_get_email_settings,
|
||||
_save_email_settings,
|
||||
get_email_event_settings,
|
||||
reset_event_settings,
|
||||
router,
|
||||
update_event_settings,
|
||||
)
|
||||
from litellm_enterprise.types.enterprise_callbacks.send_emails import (
|
||||
DefaultEmailSettings,
|
||||
EmailEvent,
|
||||
EmailEventSettings,
|
||||
EmailEventSettingsUpdateRequest,
|
||||
)
|
||||
|
||||
|
||||
# Mock user_api_key_auth dependency
|
||||
@pytest.fixture
|
||||
def mock_user_api_key_auth():
|
||||
return {"user_id": "test_user"}
|
||||
|
||||
|
||||
# Mock prisma client
|
||||
@pytest.fixture
|
||||
def mock_prisma_client():
|
||||
mock_client = mock.MagicMock()
|
||||
|
||||
# Setup mock for async methods to work properly
|
||||
mock_db = mock.MagicMock()
|
||||
mock_config = mock.MagicMock()
|
||||
|
||||
# Make find_unique return a coroutine mock
|
||||
async def mock_find_unique(*args, **kwargs):
|
||||
return None
|
||||
|
||||
mock_config.find_unique = mock_find_unique
|
||||
|
||||
# Make upsert return a coroutine mock
|
||||
async def mock_upsert(*args, **kwargs):
|
||||
return None
|
||||
|
||||
mock_config.upsert = mock_upsert
|
||||
|
||||
mock_db.litellm_config = mock_config
|
||||
mock_client.db = mock_db
|
||||
|
||||
return mock_client
|
||||
|
||||
|
||||
# Test _get_email_settings helper function
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_email_settings_empty_db(mock_prisma_client):
|
||||
"""Test that default settings are returned when database has no email settings."""
|
||||
|
||||
# Setup mock find_unique to return None
|
||||
async def mock_find_unique(*args, **kwargs):
|
||||
return None
|
||||
|
||||
mock_prisma_client.db.litellm_config.find_unique = mock_find_unique
|
||||
|
||||
# Call the function
|
||||
result = await _get_email_settings(mock_prisma_client)
|
||||
|
||||
# Assert that default settings are returned
|
||||
assert result == DefaultEmailSettings.get_defaults()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_email_settings_with_existing_settings(mock_prisma_client):
|
||||
"""Test that existing email settings are correctly retrieved from the database."""
|
||||
# Setup mock find_unique to return existing settings
|
||||
mock_settings = {
|
||||
"email_settings": {
|
||||
EmailEvent.virtual_key_created.value: True,
|
||||
EmailEvent.new_user_invitation.value: False,
|
||||
}
|
||||
}
|
||||
|
||||
mock_entry = mock.MagicMock()
|
||||
mock_entry.param_value = json.dumps(mock_settings)
|
||||
|
||||
async def mock_find_unique(*args, **kwargs):
|
||||
return mock_entry
|
||||
|
||||
mock_prisma_client.db.litellm_config.find_unique = mock_find_unique
|
||||
|
||||
# Call the function
|
||||
result = await _get_email_settings(mock_prisma_client)
|
||||
|
||||
# Assert correct settings are returned
|
||||
assert result[EmailEvent.virtual_key_created.value] is True
|
||||
assert result[EmailEvent.new_user_invitation.value] is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_save_email_settings_new_entry(mock_prisma_client):
|
||||
"""Test that email settings are properly saved to database when no previous settings exist."""
|
||||
|
||||
# Setup mock find_unique to return None
|
||||
async def mock_find_unique(*args, **kwargs):
|
||||
return None
|
||||
|
||||
mock_prisma_client.db.litellm_config.find_unique = mock_find_unique
|
||||
|
||||
# Setup mock upsert to return None
|
||||
async def mock_upsert(*args, **kwargs):
|
||||
return None
|
||||
|
||||
mock_prisma_client.db.litellm_config.upsert = mock_upsert
|
||||
|
||||
# Settings to save
|
||||
settings = {
|
||||
EmailEvent.virtual_key_created.value: True,
|
||||
EmailEvent.new_user_invitation.value: False,
|
||||
}
|
||||
|
||||
# Call the function
|
||||
await _save_email_settings(mock_prisma_client, settings)
|
||||
|
||||
# Success if no exception was raised
|
||||
|
||||
|
||||
# Test the GET endpoint
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_email_event_settings(mock_prisma_client, mock_user_api_key_auth):
|
||||
"""Test that the GET endpoint returns the correct email event settings."""
|
||||
|
||||
# Mock _get_email_settings to return test data
|
||||
async def mock_get_settings(*args, **kwargs):
|
||||
return {
|
||||
EmailEvent.virtual_key_created.value: True,
|
||||
EmailEvent.new_user_invitation.value: False,
|
||||
}
|
||||
|
||||
# Setup mocks
|
||||
with mock.patch("litellm.proxy.proxy_server.prisma_client", mock_prisma_client):
|
||||
with mock.patch(
|
||||
"litellm_enterprise.enterprise_callbacks.send_emails.endpoints._get_email_settings",
|
||||
side_effect=mock_get_settings,
|
||||
):
|
||||
# Call the endpoint function directly
|
||||
response = await get_email_event_settings(
|
||||
user_api_key_dict=mock_user_api_key_auth
|
||||
)
|
||||
|
||||
# Assert response contains correct settings
|
||||
assert isinstance(response.dict(), dict)
|
||||
assert "settings" in response.dict()
|
||||
settings = response.dict()["settings"]
|
||||
assert len(settings) == len(EmailEvent)
|
||||
|
||||
# Find the setting for virtual_key_created and check its value
|
||||
virtual_key_setting = next(
|
||||
(s for s in settings if s["event"] == EmailEvent.virtual_key_created),
|
||||
None,
|
||||
)
|
||||
assert virtual_key_setting is not None
|
||||
assert virtual_key_setting["enabled"] is True
|
||||
|
||||
|
||||
# Test the PATCH endpoint
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_event_settings(mock_prisma_client, mock_user_api_key_auth):
|
||||
"""Test that the PATCH endpoint correctly updates email event settings."""
|
||||
|
||||
# Mock _get_email_settings to return default settings
|
||||
async def mock_get_settings(*args, **kwargs):
|
||||
return DefaultEmailSettings.get_defaults()
|
||||
|
||||
# Mock _save_email_settings to do nothing
|
||||
async def mock_save_settings(*args, **kwargs):
|
||||
return None
|
||||
|
||||
# Setup mocks
|
||||
with mock.patch("litellm.proxy.proxy_server.prisma_client", mock_prisma_client):
|
||||
with mock.patch(
|
||||
"litellm_enterprise.enterprise_callbacks.send_emails.endpoints._get_email_settings",
|
||||
side_effect=mock_get_settings,
|
||||
):
|
||||
with mock.patch(
|
||||
"litellm_enterprise.enterprise_callbacks.send_emails.endpoints._save_email_settings",
|
||||
side_effect=mock_save_settings,
|
||||
):
|
||||
# Create request with updated settings
|
||||
request = EmailEventSettingsUpdateRequest(
|
||||
settings=[
|
||||
EmailEventSettings(
|
||||
event=EmailEvent.virtual_key_created, enabled=True
|
||||
),
|
||||
EmailEventSettings(
|
||||
event=EmailEvent.new_user_invitation, enabled=False
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
# Call the endpoint function directly
|
||||
response = await update_event_settings(
|
||||
request=request, user_api_key_dict=mock_user_api_key_auth
|
||||
)
|
||||
|
||||
# Assert response is success
|
||||
assert (
|
||||
response["message"] == "Email event settings updated successfully"
|
||||
)
|
||||
|
||||
|
||||
# Test the reset endpoint
|
||||
@pytest.mark.asyncio
|
||||
async def test_reset_event_settings(mock_prisma_client, mock_user_api_key_auth):
|
||||
"""Test that the reset endpoint correctly restores default email event settings."""
|
||||
|
||||
# Mock _save_email_settings to do nothing
|
||||
async def mock_save_settings(*args, **kwargs):
|
||||
return None
|
||||
|
||||
# Setup mocks
|
||||
with mock.patch("litellm.proxy.proxy_server.prisma_client", mock_prisma_client):
|
||||
with mock.patch(
|
||||
"litellm_enterprise.enterprise_callbacks.send_emails.endpoints._save_email_settings",
|
||||
side_effect=mock_save_settings,
|
||||
):
|
||||
# Call the endpoint function directly
|
||||
response = await reset_event_settings(
|
||||
user_api_key_dict=mock_user_api_key_auth
|
||||
)
|
||||
|
||||
# Assert response is success
|
||||
assert response["message"] == "Email event settings reset to defaults"
|
||||
|
||||
|
||||
# Test handling of prisma client None
|
||||
@pytest.mark.asyncio
|
||||
async def test_endpoint_with_no_prisma_client(mock_user_api_key_auth):
|
||||
"""Test that all endpoints properly handle the case when the database is not connected."""
|
||||
# Setup mock to return None for prisma_client
|
||||
with mock.patch("litellm.proxy.proxy_server.prisma_client", None):
|
||||
# Test get endpoint
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_email_event_settings(user_api_key_dict=mock_user_api_key_auth)
|
||||
assert exc_info.value.status_code == 500
|
||||
assert "Database not connected" in exc_info.value.detail
|
||||
|
||||
# Test update endpoint
|
||||
request = EmailEventSettingsUpdateRequest(settings=[])
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await update_event_settings(
|
||||
request=request, user_api_key_dict=mock_user_api_key_auth
|
||||
)
|
||||
assert exc_info.value.status_code == 500
|
||||
|
||||
# Test reset endpoint
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await reset_event_settings(user_api_key_dict=mock_user_api_key_auth)
|
||||
assert exc_info.value.status_code == 500
|
@@ -0,0 +1,119 @@
|
||||
import os
|
||||
import sys
|
||||
import unittest.mock as mock
|
||||
|
||||
import pytest
|
||||
from httpx import Response
|
||||
|
||||
sys.path.insert(0, os.path.abspath("../../.."))
|
||||
|
||||
from litellm_enterprise.enterprise_callbacks.send_emails.resend_email import (
|
||||
ResendEmailLogger,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_env_vars():
|
||||
with mock.patch.dict(os.environ, {"RESEND_API_KEY": "test_api_key"}):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_httpx_client():
|
||||
with mock.patch(
|
||||
"litellm_enterprise.enterprise_callbacks.send_emails.resend_email.get_async_httpx_client"
|
||||
) as mock_client:
|
||||
# Create a mock response
|
||||
mock_response = mock.AsyncMock(spec=Response)
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {"id": "test_email_id"}
|
||||
|
||||
# Create a mock client
|
||||
mock_async_client = mock.AsyncMock()
|
||||
mock_async_client.post.return_value = mock_response
|
||||
mock_client.return_value = mock_async_client
|
||||
|
||||
yield mock_async_client
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_email_success(mock_env_vars, mock_httpx_client):
|
||||
# Initialize the logger
|
||||
logger = ResendEmailLogger()
|
||||
|
||||
# Test data
|
||||
from_email = "test@example.com"
|
||||
to_email = ["recipient@example.com"]
|
||||
subject = "Test Subject"
|
||||
html_body = "<p>Test email body</p>"
|
||||
|
||||
# Send email
|
||||
await logger.send_email(
|
||||
from_email=from_email, to_email=to_email, subject=subject, html_body=html_body
|
||||
)
|
||||
|
||||
# Verify the HTTP client was called correctly
|
||||
mock_httpx_client.post.assert_called_once()
|
||||
call_args = mock_httpx_client.post.call_args
|
||||
|
||||
# Verify the URL
|
||||
assert call_args[1]["url"] == "https://api.resend.com/emails"
|
||||
|
||||
# Verify the request body
|
||||
request_body = call_args[1]["json"]
|
||||
assert request_body["from"] == from_email
|
||||
assert request_body["to"] == to_email
|
||||
assert request_body["subject"] == subject
|
||||
assert request_body["html"] == html_body
|
||||
|
||||
# Verify the headers
|
||||
assert call_args[1]["headers"] == {"Authorization": "Bearer test_api_key"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_email_missing_api_key(mock_httpx_client):
|
||||
# Remove the API key from environment
|
||||
if "RESEND_API_KEY" in os.environ:
|
||||
del os.environ["RESEND_API_KEY"]
|
||||
|
||||
# Initialize the logger
|
||||
logger = ResendEmailLogger()
|
||||
|
||||
# Test data
|
||||
from_email = "test@example.com"
|
||||
to_email = ["recipient@example.com"]
|
||||
subject = "Test Subject"
|
||||
html_body = "<p>Test email body</p>"
|
||||
|
||||
# Send email
|
||||
await logger.send_email(
|
||||
from_email=from_email, to_email=to_email, subject=subject, html_body=html_body
|
||||
)
|
||||
|
||||
# Verify the HTTP client was called with None as the API key
|
||||
mock_httpx_client.post.assert_called_once()
|
||||
call_args = mock_httpx_client.post.call_args
|
||||
assert call_args[1]["headers"] == {"Authorization": "Bearer None"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_email_multiple_recipients(mock_env_vars, mock_httpx_client):
|
||||
# Initialize the logger
|
||||
logger = ResendEmailLogger()
|
||||
|
||||
# Test data with multiple recipients
|
||||
from_email = "test@example.com"
|
||||
to_email = ["recipient1@example.com", "recipient2@example.com"]
|
||||
subject = "Test Subject"
|
||||
html_body = "<p>Test email body</p>"
|
||||
|
||||
# Send email
|
||||
await logger.send_email(
|
||||
from_email=from_email, to_email=to_email, subject=subject, html_body=html_body
|
||||
)
|
||||
|
||||
# Verify the HTTP client was called with multiple recipients
|
||||
mock_httpx_client.post.assert_called_once()
|
||||
call_args = mock_httpx_client.post.call_args
|
||||
request_body = call_args[1]["json"]
|
||||
assert request_body["to"] == to_email
|
@@ -0,0 +1,216 @@
|
||||
import unittest.mock as mock
|
||||
from typing import cast
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from enterprise.litellm_enterprise.enterprise_callbacks.callback_controls import (
|
||||
EnterpriseCallbackControls,
|
||||
)
|
||||
from litellm.constants import X_LITELLM_DISABLE_CALLBACKS
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm.integrations.datadog.datadog import DataDogLogger
|
||||
from litellm.integrations.langfuse.langfuse_prompt_management import (
|
||||
LangfusePromptManagement,
|
||||
)
|
||||
from litellm.integrations.s3_v2 import S3Logger
|
||||
from litellm.types.utils import StandardCallbackDynamicParams
|
||||
|
||||
|
||||
class TestEnterpriseCallbackControls:
|
||||
|
||||
@pytest.fixture
|
||||
def mock_premium_user(self):
|
||||
"""Fixture to mock premium user check as True"""
|
||||
with patch.object(EnterpriseCallbackControls, '_premium_user_check', return_value=True):
|
||||
yield
|
||||
|
||||
@pytest.fixture
|
||||
def mock_non_premium_user(self):
|
||||
"""Fixture to mock premium user check as False"""
|
||||
with patch.object(EnterpriseCallbackControls, '_premium_user_check', return_value=False):
|
||||
yield
|
||||
|
||||
@pytest.fixture
|
||||
def mock_request_headers(self):
|
||||
"""Fixture to mock get_proxy_server_request_headers"""
|
||||
with patch('enterprise.litellm_enterprise.enterprise_callbacks.callback_controls.get_proxy_server_request_headers') as mock_headers:
|
||||
yield mock_headers
|
||||
|
||||
def test_callback_disabled_langfuse_string(self, mock_premium_user, mock_request_headers):
|
||||
"""Test that 'langfuse' string callback is disabled when specified in headers"""
|
||||
mock_request_headers.return_value = {X_LITELLM_DISABLE_CALLBACKS: "langfuse"}
|
||||
litellm_params = {"proxy_server_request": {"url": "test"}}
|
||||
standard_callback_dynamic_params = StandardCallbackDynamicParams()
|
||||
|
||||
result = EnterpriseCallbackControls.is_callback_disabled_dynamically("langfuse", litellm_params, standard_callback_dynamic_params)
|
||||
assert result is True
|
||||
|
||||
def test_callback_disabled_langfuse_customlogger(self, mock_premium_user, mock_request_headers):
|
||||
"""Test that LangfusePromptManagement CustomLogger instance is disabled when 'langfuse' specified in headers"""
|
||||
mock_request_headers.return_value = {X_LITELLM_DISABLE_CALLBACKS: "langfuse"}
|
||||
litellm_params = {"proxy_server_request": {"url": "test"}}
|
||||
standard_callback_dynamic_params = StandardCallbackDynamicParams()
|
||||
|
||||
langfuse_logger = LangfusePromptManagement()
|
||||
result = EnterpriseCallbackControls.is_callback_disabled_dynamically(langfuse_logger, litellm_params, standard_callback_dynamic_params)
|
||||
assert result is True
|
||||
|
||||
def test_callback_disabled_s3_v2_string(self, mock_premium_user, mock_request_headers):
|
||||
"""Test that 's3_v2' string callback is disabled when specified in headers"""
|
||||
mock_request_headers.return_value = {X_LITELLM_DISABLE_CALLBACKS: "s3_v2"}
|
||||
litellm_params = {"proxy_server_request": {"url": "test"}}
|
||||
standard_callback_dynamic_params = StandardCallbackDynamicParams()
|
||||
|
||||
result = EnterpriseCallbackControls.is_callback_disabled_dynamically("s3_v2", litellm_params, standard_callback_dynamic_params)
|
||||
assert result is True
|
||||
|
||||
def test_callback_disabled_s3_v2_customlogger(self, mock_premium_user, mock_request_headers):
|
||||
"""Test that S3Logger CustomLogger instance is disabled when 's3_v2' specified in headers"""
|
||||
mock_request_headers.return_value = {X_LITELLM_DISABLE_CALLBACKS: "s3_v2"}
|
||||
litellm_params = {"proxy_server_request": {"url": "test"}}
|
||||
standard_callback_dynamic_params = StandardCallbackDynamicParams()
|
||||
|
||||
# Mock S3Logger to avoid async initialization issues
|
||||
with patch('litellm.integrations.s3_v2.S3Logger.__init__', return_value=None):
|
||||
s3_logger = S3Logger()
|
||||
result = EnterpriseCallbackControls.is_callback_disabled_dynamically(s3_logger, litellm_params, standard_callback_dynamic_params)
|
||||
assert result is True
|
||||
|
||||
def test_callback_disabled_datadog_string(self, mock_premium_user, mock_request_headers):
|
||||
"""Test that 'datadog' string callback is disabled when specified in headers"""
|
||||
mock_request_headers.return_value = {X_LITELLM_DISABLE_CALLBACKS: "datadog"}
|
||||
litellm_params = {"proxy_server_request": {"url": "test"}}
|
||||
standard_callback_dynamic_params = StandardCallbackDynamicParams()
|
||||
|
||||
result = EnterpriseCallbackControls.is_callback_disabled_dynamically("datadog", litellm_params, standard_callback_dynamic_params)
|
||||
assert result is True
|
||||
|
||||
def test_callback_disabled_datadog_customlogger(self, mock_premium_user, mock_request_headers):
|
||||
"""Test that DataDogLogger CustomLogger instance is disabled when 'datadog' specified in headers"""
|
||||
mock_request_headers.return_value = {X_LITELLM_DISABLE_CALLBACKS: "datadog"}
|
||||
litellm_params = {"proxy_server_request": {"url": "test"}}
|
||||
standard_callback_dynamic_params = StandardCallbackDynamicParams()
|
||||
|
||||
# Mock DataDogLogger to avoid async initialization issues
|
||||
with patch('litellm.integrations.datadog.datadog.DataDogLogger.__init__', return_value=None):
|
||||
datadog_logger = DataDogLogger()
|
||||
result = EnterpriseCallbackControls.is_callback_disabled_dynamically(datadog_logger, litellm_params, standard_callback_dynamic_params)
|
||||
assert result is True
|
||||
|
||||
def test_multiple_callbacks_disabled(self, mock_premium_user, mock_request_headers):
|
||||
"""Test that multiple callbacks can be disabled with comma-separated list"""
|
||||
mock_request_headers.return_value = {X_LITELLM_DISABLE_CALLBACKS: "langfuse,datadog,s3_v2"}
|
||||
litellm_params = {"proxy_server_request": {"url": "test"}}
|
||||
standard_callback_dynamic_params = StandardCallbackDynamicParams()
|
||||
|
||||
# Test each callback is disabled
|
||||
assert EnterpriseCallbackControls.is_callback_disabled_dynamically("langfuse", litellm_params, standard_callback_dynamic_params) is True
|
||||
assert EnterpriseCallbackControls.is_callback_disabled_dynamically("datadog", litellm_params, standard_callback_dynamic_params) is True
|
||||
assert EnterpriseCallbackControls.is_callback_disabled_dynamically("s3_v2", litellm_params, standard_callback_dynamic_params) is True
|
||||
|
||||
# Test non-disabled callback is not disabled
|
||||
assert EnterpriseCallbackControls.is_callback_disabled_dynamically("prometheus", litellm_params, standard_callback_dynamic_params) is False
|
||||
|
||||
def test_callback_not_disabled_when_not_in_list(self, mock_premium_user, mock_request_headers):
|
||||
"""Test that callbacks not in the disabled list are not disabled"""
|
||||
mock_request_headers.return_value = {X_LITELLM_DISABLE_CALLBACKS: "langfuse"}
|
||||
litellm_params = {"proxy_server_request": {"url": "test"}}
|
||||
standard_callback_dynamic_params = StandardCallbackDynamicParams()
|
||||
|
||||
result = EnterpriseCallbackControls.is_callback_disabled_dynamically("datadog", litellm_params, standard_callback_dynamic_params)
|
||||
assert result is False
|
||||
|
||||
def test_callback_not_disabled_when_no_header(self, mock_premium_user, mock_request_headers):
|
||||
"""Test that callbacks are not disabled when the header is not present"""
|
||||
mock_request_headers.return_value = {}
|
||||
litellm_params = {"proxy_server_request": {"url": "test"}}
|
||||
standard_callback_dynamic_params = StandardCallbackDynamicParams()
|
||||
|
||||
result = EnterpriseCallbackControls.is_callback_disabled_dynamically("langfuse", litellm_params, standard_callback_dynamic_params)
|
||||
assert result is False
|
||||
|
||||
def test_callback_not_disabled_when_header_none(self, mock_premium_user, mock_request_headers):
|
||||
"""Test that callbacks are not disabled when the header value is None"""
|
||||
mock_request_headers.return_value = {X_LITELLM_DISABLE_CALLBACKS: None}
|
||||
litellm_params = {"proxy_server_request": {"url": "test"}}
|
||||
standard_callback_dynamic_params = StandardCallbackDynamicParams()
|
||||
|
||||
result = EnterpriseCallbackControls.is_callback_disabled_dynamically("langfuse", litellm_params, standard_callback_dynamic_params)
|
||||
assert result is False
|
||||
|
||||
def test_non_premium_user_cannot_disable_callbacks(self, mock_non_premium_user, mock_request_headers):
|
||||
"""Test that non-premium users cannot disable callbacks even with the header"""
|
||||
mock_request_headers.return_value = {X_LITELLM_DISABLE_CALLBACKS: "langfuse"}
|
||||
litellm_params = {"proxy_server_request": {"url": "test"}}
|
||||
standard_callback_dynamic_params = StandardCallbackDynamicParams()
|
||||
|
||||
result = EnterpriseCallbackControls.is_callback_disabled_dynamically("langfuse", litellm_params, standard_callback_dynamic_params)
|
||||
assert result is False
|
||||
|
||||
def test_case_insensitive_callback_matching(self, mock_premium_user, mock_request_headers):
|
||||
"""Test that callback matching is case insensitive"""
|
||||
mock_request_headers.return_value = {X_LITELLM_DISABLE_CALLBACKS: "LANGFUSE,DataDog"}
|
||||
litellm_params = {"proxy_server_request": {"url": "test"}}
|
||||
standard_callback_dynamic_params = StandardCallbackDynamicParams()
|
||||
|
||||
# Test lowercase callbacks are disabled
|
||||
assert EnterpriseCallbackControls.is_callback_disabled_dynamically("langfuse", litellm_params, standard_callback_dynamic_params) is True
|
||||
assert EnterpriseCallbackControls.is_callback_disabled_dynamically("datadog", litellm_params, standard_callback_dynamic_params) is True
|
||||
|
||||
def test_whitespace_handling_in_disabled_callbacks(self, mock_premium_user, mock_request_headers):
|
||||
"""Test that whitespace around callback names is handled correctly"""
|
||||
mock_request_headers.return_value = {X_LITELLM_DISABLE_CALLBACKS: " langfuse , datadog , s3_v2 "}
|
||||
litellm_params = {"proxy_server_request": {"url": "test"}}
|
||||
standard_callback_dynamic_params = StandardCallbackDynamicParams()
|
||||
|
||||
assert EnterpriseCallbackControls.is_callback_disabled_dynamically("langfuse", litellm_params, standard_callback_dynamic_params) is True
|
||||
assert EnterpriseCallbackControls.is_callback_disabled_dynamically("datadog", litellm_params, standard_callback_dynamic_params) is True
|
||||
assert EnterpriseCallbackControls.is_callback_disabled_dynamically("s3_v2", litellm_params, standard_callback_dynamic_params) is True
|
||||
|
||||
def test_custom_logger_not_in_registry(self, mock_premium_user, mock_request_headers):
|
||||
"""Test that CustomLogger not in registry is not disabled"""
|
||||
mock_request_headers.return_value = {X_LITELLM_DISABLE_CALLBACKS: "unknown_logger"}
|
||||
litellm_params = {"proxy_server_request": {"url": "test"}}
|
||||
standard_callback_dynamic_params = StandardCallbackDynamicParams()
|
||||
|
||||
# Create a mock CustomLogger that's not in the registry
|
||||
class UnknownLogger(CustomLogger):
|
||||
pass
|
||||
|
||||
unknown_logger = UnknownLogger()
|
||||
result = EnterpriseCallbackControls.is_callback_disabled_dynamically(unknown_logger, litellm_params, standard_callback_dynamic_params)
|
||||
assert result is False
|
||||
|
||||
def test_exception_handling(self, mock_premium_user, mock_request_headers):
|
||||
"""Test that exceptions are handled gracefully and return False"""
|
||||
# Make get_proxy_server_request_headers raise an exception
|
||||
mock_request_headers.side_effect = Exception("Test exception")
|
||||
litellm_params = {"proxy_server_request": {"url": "test"}}
|
||||
standard_callback_dynamic_params = StandardCallbackDynamicParams()
|
||||
|
||||
result = EnterpriseCallbackControls.is_callback_disabled_dynamically("langfuse", litellm_params, standard_callback_dynamic_params)
|
||||
assert result is False
|
||||
|
||||
def test_callback_disabled_via_request_body_langfuse(self, mock_premium_user, mock_request_headers):
|
||||
"""Test that callbacks can be disabled via request body litellm_disabled_callbacks"""
|
||||
mock_request_headers.return_value = {} # No headers
|
||||
litellm_params = {"proxy_server_request": {"url": "test"}}
|
||||
standard_callback_dynamic_params = StandardCallbackDynamicParams(litellm_disabled_callbacks=["langfuse"])
|
||||
|
||||
result = EnterpriseCallbackControls.is_callback_disabled_dynamically("langfuse", litellm_params, standard_callback_dynamic_params)
|
||||
assert result is True
|
||||
|
||||
def test_callback_disabled_via_request_body_multiple(self, mock_premium_user, mock_request_headers):
|
||||
"""Test that multiple callbacks can be disabled via request body"""
|
||||
mock_request_headers.return_value = {} # No headers
|
||||
litellm_params = {"proxy_server_request": {"url": "test"}}
|
||||
standard_callback_dynamic_params = StandardCallbackDynamicParams(litellm_disabled_callbacks=["langfuse", "datadog", "s3_v2"])
|
||||
|
||||
# Test each callback is disabled
|
||||
assert EnterpriseCallbackControls.is_callback_disabled_dynamically("langfuse", litellm_params, standard_callback_dynamic_params) is True
|
||||
assert EnterpriseCallbackControls.is_callback_disabled_dynamically("datadog", litellm_params, standard_callback_dynamic_params) is True
|
||||
assert EnterpriseCallbackControls.is_callback_disabled_dynamically("s3_v2", litellm_params, standard_callback_dynamic_params) is True
|
||||
|
||||
# Test non-disabled callback is not disabled
|
||||
assert EnterpriseCallbackControls.is_callback_disabled_dynamically("prometheus", litellm_params, standard_callback_dynamic_params) is False
|
Reference in New Issue
Block a user