Added LiteLLM to the stack

This commit is contained in:
2025-08-18 09:40:50 +00:00
parent 0648c1968c
commit d220b04e32
2682 changed files with 533609 additions and 1 deletions

View File

@@ -0,0 +1,492 @@
import json
import os
import sys
import unittest.mock as mock
from unittest.mock import patch
import pytest
from fastapi.testclient import TestClient
sys.path.insert(0, os.path.abspath("../../.."))
from litellm_enterprise.enterprise_callbacks.send_emails.base_email import (
BaseEmailLogger,
)
from litellm_enterprise.types.enterprise_callbacks.send_emails import (
EmailEvent,
SendKeyCreatedEmailEvent,
)
from litellm.integrations.email_templates.email_footer import EMAIL_FOOTER
from litellm.proxy._types import Litellm_EntityType, WebhookEvent
@pytest.fixture
def base_email_logger():
return BaseEmailLogger()
@pytest.fixture
def mock_send_email():
with mock.patch.object(BaseEmailLogger, "send_email") as mock_send:
yield mock_send
@pytest.fixture
def mock_lookup_user_email():
with mock.patch.object(
BaseEmailLogger, "_lookup_user_email_from_db"
) as mock_lookup:
yield mock_lookup
def test_format_key_budget(base_email_logger):
# Test with budget
assert base_email_logger._format_key_budget(100.0) == "$100.0"
# Test with no budget
assert base_email_logger._format_key_budget(None) == "No budget"
@pytest.mark.asyncio
async def test_send_key_created_email(
base_email_logger, mock_send_email, mock_lookup_user_email
):
# Setup test data
event = SendKeyCreatedEmailEvent(
user_id="test_user",
user_email="test@example.com",
virtual_key="test_key",
max_budget=100.0,
spend=0.0,
event_group=Litellm_EntityType.USER,
event="key_created",
event_message="Test Key Created",
)
# Mock environment variables
with mock.patch.dict(
os.environ,
{
"EMAIL_LOGO_URL": "https://litellm-listing.s3.amazonaws.com/litellm_logo.png",
"EMAIL_SUPPORT_CONTACT": "support@berri.ai",
"PROXY_BASE_URL": "http://test.com",
},
):
# Execute
await base_email_logger.send_key_created_email(event)
# Verify
mock_send_email.assert_called_once()
call_args = mock_send_email.call_args[1]
assert call_args["from_email"] == BaseEmailLogger.DEFAULT_LITELLM_EMAIL
assert call_args["to_email"] == ["test@example.com"]
assert call_args["subject"] == "LiteLLM: Test Key Created"
assert "test_key" in call_args["html_body"]
assert "$100.0" in call_args["html_body"]
@pytest.mark.asyncio
async def test_send_user_invitation_email(
base_email_logger, mock_send_email, mock_lookup_user_email
):
# Setup test data
event = WebhookEvent(
user_id="test_user",
user_email="invited@example.com",
event_group=Litellm_EntityType.USER,
event="internal_user_created",
event_message="User Invitation",
spend=0.0,
)
# Mock environment variables
with mock.patch.dict(
os.environ,
{
"EMAIL_LOGO_URL": "https://litellm-listing.s3.amazonaws.com/litellm_logo.png",
"EMAIL_SUPPORT_CONTACT": "support@berri.ai",
"PROXY_BASE_URL": "http://test.com",
},
):
# Execute
await base_email_logger.send_user_invitation_email(event)
# Verify
mock_send_email.assert_called_once()
call_args = mock_send_email.call_args[1]
assert call_args["from_email"] == BaseEmailLogger.DEFAULT_LITELLM_EMAIL
assert call_args["to_email"] == ["invited@example.com"]
assert call_args["subject"] == "LiteLLM: User Invitation"
assert "invited@example.com" in call_args["html_body"]
@pytest.mark.asyncio
async def test_send_user_invitation_email_from_db(
base_email_logger, mock_send_email, mock_lookup_user_email
):
# Setup test data with no direct email but one in the database
event = WebhookEvent(
user_id="test_user",
event_group=Litellm_EntityType.USER,
event="internal_user_created",
event_message="User Invitation",
spend=0.0,
)
# Mock the lookup to return an email
mock_lookup_user_email.return_value = "db_user@example.com"
# Mock environment variables
with mock.patch.dict(
os.environ,
{
"EMAIL_LOGO_URL": "https://litellm-listing.s3.amazonaws.com/litellm_logo.png",
"EMAIL_SUPPORT_CONTACT": "support@berri.ai",
"PROXY_BASE_URL": "http://test.com",
},
):
# Execute
await base_email_logger.send_user_invitation_email(event)
# Verify
mock_lookup_user_email.assert_called_once_with(user_id="test_user")
mock_send_email.assert_called_once()
call_args = mock_send_email.call_args[1]
assert call_args["from_email"] == BaseEmailLogger.DEFAULT_LITELLM_EMAIL
assert call_args["to_email"] == ["db_user@example.com"]
assert call_args["subject"] == "LiteLLM: User Invitation"
assert "db_user@example.com" in call_args["html_body"]
@pytest.mark.asyncio
async def test_send_user_invitation_email_no_email(
base_email_logger, mock_lookup_user_email
):
# Setup test data with no email
event = WebhookEvent(
user_id="test_user",
event_group=Litellm_EntityType.USER,
event="internal_user_created",
event_message="User Invitation",
spend=0.0,
)
# Mock lookup to return None
mock_lookup_user_email.return_value = None
# Test that it raises ValueError
with pytest.raises(ValueError, match="User email not found"):
await base_email_logger.send_user_invitation_email(event)
@pytest.mark.asyncio
async def test_send_key_created_email_no_email(
base_email_logger, mock_lookup_user_email
):
# Setup test data with no email
event = SendKeyCreatedEmailEvent(
user_id="test_user",
user_email=None,
virtual_key="test_key",
max_budget=100.0,
event_message="Test Key Created",
event_group=Litellm_EntityType.USER,
event="key_created",
spend=0.0,
)
# Mock lookup to return None
mock_lookup_user_email.return_value = None
# Test that it raises ValueError
with pytest.raises(ValueError, match="User email not found"):
await base_email_logger.send_key_created_email(event)
@pytest.mark.asyncio
async def test_get_invitation_link(base_email_logger):
# Mock prisma client and its response
mock_invitation_row = mock.MagicMock()
mock_invitation_row.id = "test-invitation-id"
mock_prisma = mock.MagicMock()
# Create an async mock for find_many
async def mock_find_many(*args, **kwargs):
return [mock_invitation_row]
mock_prisma.db.litellm_invitationlink.find_many = mock_find_many
with mock.patch("litellm.proxy.proxy_server.prisma_client", mock_prisma):
# Test with valid user_id
result = await base_email_logger._get_invitation_link(
user_id="test-user", base_url="http://test.com"
)
assert result == "http://test.com/ui?invitation_id=test-invitation-id"
# Test with None user_id
result = await base_email_logger._get_invitation_link(
user_id=None, base_url="http://test.com"
)
assert result == "http://test.com"
# Test with no invitation links
async def mock_find_many_empty(*args, **kwargs):
return []
mock_prisma.db.litellm_invitationlink.find_many = mock_find_many_empty
result = await base_email_logger._get_invitation_link(
user_id="test-user", base_url="http://test.com"
)
assert result == "http://test.com"
def test_construct_invitation_link(base_email_logger):
# Test invitation link construction
result = base_email_logger._construct_invitation_link(
invitation_id="test-id-123", base_url="http://test.com"
)
assert result == "http://test.com/ui?invitation_id=test-id-123"
@pytest.mark.asyncio
async def test_get_invitation_link_creates_new_when_none_exist(base_email_logger):
"""Test that _get_invitation_link creates a new invitation when none exist"""
# Mock prisma client with no existing invitation rows
mock_prisma = mock.MagicMock()
# Mock find_many to return empty list (no existing invitations)
async def mock_find_many_empty(*args, **kwargs):
return []
mock_prisma.db.litellm_invitationlink.find_many = mock_find_many_empty
# Mock the create_invitation_for_user function
mock_created_invitation = mock.MagicMock()
mock_created_invitation.id = "new-invitation-id"
with mock.patch("litellm.proxy.proxy_server.prisma_client", mock_prisma):
with mock.patch(
"litellm.proxy.management_helpers.user_invitation.create_invitation_for_user",
return_value=mock_created_invitation
) as mock_create_invitation:
# Execute
result = await base_email_logger._get_invitation_link(
user_id="test-user", base_url="http://test.com"
)
# Verify that create_invitation_for_user was called
mock_create_invitation.assert_called_once()
call_args = mock_create_invitation.call_args[1]
assert call_args["data"].user_id == "test-user"
assert call_args["user_api_key_dict"].user_id == "test-user"
# Verify the returned link uses the new invitation ID
assert result == "http://test.com/ui?invitation_id=new-invitation-id"
@pytest.mark.asyncio
async def test_get_invitation_link_uses_existing_when_available(base_email_logger):
"""Test that _get_invitation_link uses existing invitation when available"""
# Mock prisma client with existing invitation row
mock_invitation_row = mock.MagicMock()
mock_invitation_row.id = "existing-invitation-id"
mock_prisma = mock.MagicMock()
# Mock find_many to return existing invitation
async def mock_find_many_existing(*args, **kwargs):
return [mock_invitation_row]
mock_prisma.db.litellm_invitationlink.find_many = mock_find_many_existing
with mock.patch("litellm.proxy.proxy_server.prisma_client", mock_prisma):
with mock.patch(
"litellm.proxy.management_helpers.user_invitation.create_invitation_for_user"
) as mock_create_invitation:
# Execute
result = await base_email_logger._get_invitation_link(
user_id="test-user", base_url="http://test.com"
)
# Verify that create_invitation_for_user was NOT called
mock_create_invitation.assert_not_called()
# Verify the returned link uses the existing invitation ID
assert result == "http://test.com/ui?invitation_id=existing-invitation-id"
@pytest.mark.asyncio
async def test_get_invitation_link_creates_new_when_list_is_none(base_email_logger):
"""Test that _get_invitation_link creates a new invitation when invitation_rows is None"""
# Mock prisma client to return None
mock_prisma = mock.MagicMock()
# Mock find_many to return None
async def mock_find_many_none(*args, **kwargs):
return None
mock_prisma.db.litellm_invitationlink.find_many = mock_find_many_none
# Mock the create_invitation_for_user function
mock_created_invitation = mock.MagicMock()
mock_created_invitation.id = "new-invitation-from-none"
with mock.patch("litellm.proxy.proxy_server.prisma_client", mock_prisma):
with mock.patch(
"litellm.proxy.management_helpers.user_invitation.create_invitation_for_user",
return_value=mock_created_invitation
) as mock_create_invitation:
# Execute
result = await base_email_logger._get_invitation_link(
user_id="test-user", base_url="http://test.com"
)
# Verify that create_invitation_for_user was called
mock_create_invitation.assert_called_once()
call_args = mock_create_invitation.call_args[1]
assert call_args["data"].user_id == "test-user"
assert call_args["user_api_key_dict"].user_id == "test-user"
# Verify the returned link uses the new invitation ID
assert result == "http://test.com/ui?invitation_id=new-invitation-from-none"
@pytest.mark.asyncio
async def test_get_email_params_user_invitation(
base_email_logger, mock_lookup_user_email
):
# Mock environment variables
with mock.patch.dict(
os.environ,
{
"EMAIL_LOGO_URL": "https://litellm-listing.s3.amazonaws.com/litellm_logo.png",
"EMAIL_SUPPORT_CONTACT": "support@berri.ai",
"PROXY_BASE_URL": "http://test.com",
},
):
# Mock invitation link
with mock.patch.object(
base_email_logger,
"_get_invitation_link",
return_value="http://test.com/ui?invitation_id=test-id",
):
# Test with user invitation event
result = await base_email_logger._get_email_params(
email_event=EmailEvent.new_user_invitation,
user_id="test-user",
user_email="test@example.com",
)
assert result.logo_url == "https://litellm-listing.s3.amazonaws.com/litellm_logo.png"
assert result.support_contact == "support@berri.ai"
assert result.base_url == "http://test.com/ui?invitation_id=test-id"
assert result.recipient_email == "test@example.com"
@pytest.fixture
def mock_env_vars(monkeypatch):
"""Set up test environment variables"""
monkeypatch.setenv("EMAIL_LOGO_URL", "https://test-company.com/logo.png")
monkeypatch.setenv("EMAIL_SUPPORT_CONTACT", "support@test-company.com")
monkeypatch.setenv("EMAIL_SIGNATURE", "Best regards,\nTest Company Team")
monkeypatch.setenv("EMAIL_SUBJECT_INVITATION", "Welcome to Test Company!")
monkeypatch.setenv("EMAIL_SUBJECT_KEY_CREATED", "Your Test Company API Key")
monkeypatch.setenv("PROXY_BASE_URL", "http://test.com")
monkeypatch.setenv("PROXY_API_URL", "https://test.com")
@pytest.mark.asyncio
async def test_get_email_params_custom_templates_premium_user(mock_env_vars):
"""Test that _get_email_params returns correct values with custom templates for premium users"""
# Mock premium_user as True
with patch("litellm.proxy.proxy_server.premium_user", True):
email_logger = BaseEmailLogger()
# Test invitation email params
invitation_params = await email_logger._get_email_params(
email_event=EmailEvent.new_user_invitation,
user_id="testid",
user_email="test@example.com",
event_message="New User Invitation"
)
assert invitation_params.subject == "Welcome to Test Company!"
assert invitation_params.signature == "Best regards,\nTest Company Team"
assert invitation_params.logo_url == "https://test-company.com/logo.png"
assert invitation_params.support_contact == "support@test-company.com"
assert invitation_params.base_url == "http://test.com"
# Test key created email params
key_params = await email_logger._get_email_params(
email_event=EmailEvent.virtual_key_created,
user_id="testid",
user_email="test@example.com",
event_message="API Key Created"
)
assert key_params.subject == "Your Test Company API Key"
assert key_params.signature == "Best regards,\nTest Company Team"
@pytest.mark.asyncio
async def test_get_email_params_non_premium_user(mock_env_vars):
"""Test that non-premium users get default templates even when custom ones are provided"""
# Mock premium_user as False
with patch("litellm.proxy.proxy_server.premium_user", False):
email_logger = BaseEmailLogger()
# Test invitation email params
email_params = await email_logger._get_email_params(
email_event=EmailEvent.new_user_invitation,
user_email="test@example.com",
event_message="New User Invitation"
)
# Should use default values even though custom values are set in env
assert email_params.subject == "LiteLLM: New User Invitation"
assert email_params.signature == EMAIL_FOOTER
assert email_params.logo_url == "https://litellm-listing.s3.amazonaws.com/litellm_logo.png"
assert email_params.support_contact == "support@berri.ai"
# Test key created email params
key_params = await email_logger._get_email_params(
email_event=EmailEvent.virtual_key_created,
user_email="test@example.com",
event_message="API Key Created"
)
assert key_params.subject == "LiteLLM: API Key Created"
assert key_params.signature == EMAIL_FOOTER
@pytest.mark.asyncio
async def test_get_email_params_default_templates(monkeypatch):
"""Test that _get_email_params uses default templates when custom ones aren't provided"""
# Clear any existing environment variables
monkeypatch.delenv("EMAIL_SUBJECT_INVITATION", raising=False)
monkeypatch.delenv("EMAIL_SUBJECT_KEY_CREATED", raising=False)
monkeypatch.delenv("EMAIL_SIGNATURE", raising=False)
# Mock premium_user as True (shouldn't matter since no custom values are set)
with patch("litellm.proxy.proxy_server.premium_user", True):
email_logger = BaseEmailLogger()
# Test invitation email params with default template
invitation_params = await email_logger._get_email_params(
email_event=EmailEvent.new_user_invitation,
user_email="test@example.com",
event_message="New User Invitation"
)
assert invitation_params.subject == "LiteLLM: New User Invitation"
assert invitation_params.signature == EMAIL_FOOTER
# Test key created email params with default template
key_params = await email_logger._get_email_params(
email_event=EmailEvent.virtual_key_created,
user_email="test@example.com",
event_message="API Key Created"
)
assert key_params.subject == "LiteLLM: API Key Created"
assert key_params.signature == EMAIL_FOOTER

View File

@@ -0,0 +1,265 @@
import json
import os
import sys
import unittest.mock as mock
import pytest
from fastapi import HTTPException
from fastapi.testclient import TestClient
sys.path.insert(0, os.path.abspath("../../.."))
from litellm_enterprise.enterprise_callbacks.send_emails.endpoints import (
_get_email_settings,
_save_email_settings,
get_email_event_settings,
reset_event_settings,
router,
update_event_settings,
)
from litellm_enterprise.types.enterprise_callbacks.send_emails import (
DefaultEmailSettings,
EmailEvent,
EmailEventSettings,
EmailEventSettingsUpdateRequest,
)
# Mock user_api_key_auth dependency
@pytest.fixture
def mock_user_api_key_auth():
return {"user_id": "test_user"}
# Mock prisma client
@pytest.fixture
def mock_prisma_client():
mock_client = mock.MagicMock()
# Setup mock for async methods to work properly
mock_db = mock.MagicMock()
mock_config = mock.MagicMock()
# Make find_unique return a coroutine mock
async def mock_find_unique(*args, **kwargs):
return None
mock_config.find_unique = mock_find_unique
# Make upsert return a coroutine mock
async def mock_upsert(*args, **kwargs):
return None
mock_config.upsert = mock_upsert
mock_db.litellm_config = mock_config
mock_client.db = mock_db
return mock_client
# Test _get_email_settings helper function
@pytest.mark.asyncio
async def test_get_email_settings_empty_db(mock_prisma_client):
"""Test that default settings are returned when database has no email settings."""
# Setup mock find_unique to return None
async def mock_find_unique(*args, **kwargs):
return None
mock_prisma_client.db.litellm_config.find_unique = mock_find_unique
# Call the function
result = await _get_email_settings(mock_prisma_client)
# Assert that default settings are returned
assert result == DefaultEmailSettings.get_defaults()
@pytest.mark.asyncio
async def test_get_email_settings_with_existing_settings(mock_prisma_client):
"""Test that existing email settings are correctly retrieved from the database."""
# Setup mock find_unique to return existing settings
mock_settings = {
"email_settings": {
EmailEvent.virtual_key_created.value: True,
EmailEvent.new_user_invitation.value: False,
}
}
mock_entry = mock.MagicMock()
mock_entry.param_value = json.dumps(mock_settings)
async def mock_find_unique(*args, **kwargs):
return mock_entry
mock_prisma_client.db.litellm_config.find_unique = mock_find_unique
# Call the function
result = await _get_email_settings(mock_prisma_client)
# Assert correct settings are returned
assert result[EmailEvent.virtual_key_created.value] is True
assert result[EmailEvent.new_user_invitation.value] is False
@pytest.mark.asyncio
async def test_save_email_settings_new_entry(mock_prisma_client):
"""Test that email settings are properly saved to database when no previous settings exist."""
# Setup mock find_unique to return None
async def mock_find_unique(*args, **kwargs):
return None
mock_prisma_client.db.litellm_config.find_unique = mock_find_unique
# Setup mock upsert to return None
async def mock_upsert(*args, **kwargs):
return None
mock_prisma_client.db.litellm_config.upsert = mock_upsert
# Settings to save
settings = {
EmailEvent.virtual_key_created.value: True,
EmailEvent.new_user_invitation.value: False,
}
# Call the function
await _save_email_settings(mock_prisma_client, settings)
# Success if no exception was raised
# Test the GET endpoint
@pytest.mark.asyncio
async def test_get_email_event_settings(mock_prisma_client, mock_user_api_key_auth):
"""Test that the GET endpoint returns the correct email event settings."""
# Mock _get_email_settings to return test data
async def mock_get_settings(*args, **kwargs):
return {
EmailEvent.virtual_key_created.value: True,
EmailEvent.new_user_invitation.value: False,
}
# Setup mocks
with mock.patch("litellm.proxy.proxy_server.prisma_client", mock_prisma_client):
with mock.patch(
"litellm_enterprise.enterprise_callbacks.send_emails.endpoints._get_email_settings",
side_effect=mock_get_settings,
):
# Call the endpoint function directly
response = await get_email_event_settings(
user_api_key_dict=mock_user_api_key_auth
)
# Assert response contains correct settings
assert isinstance(response.dict(), dict)
assert "settings" in response.dict()
settings = response.dict()["settings"]
assert len(settings) == len(EmailEvent)
# Find the setting for virtual_key_created and check its value
virtual_key_setting = next(
(s for s in settings if s["event"] == EmailEvent.virtual_key_created),
None,
)
assert virtual_key_setting is not None
assert virtual_key_setting["enabled"] is True
# Test the PATCH endpoint
@pytest.mark.asyncio
async def test_update_event_settings(mock_prisma_client, mock_user_api_key_auth):
"""Test that the PATCH endpoint correctly updates email event settings."""
# Mock _get_email_settings to return default settings
async def mock_get_settings(*args, **kwargs):
return DefaultEmailSettings.get_defaults()
# Mock _save_email_settings to do nothing
async def mock_save_settings(*args, **kwargs):
return None
# Setup mocks
with mock.patch("litellm.proxy.proxy_server.prisma_client", mock_prisma_client):
with mock.patch(
"litellm_enterprise.enterprise_callbacks.send_emails.endpoints._get_email_settings",
side_effect=mock_get_settings,
):
with mock.patch(
"litellm_enterprise.enterprise_callbacks.send_emails.endpoints._save_email_settings",
side_effect=mock_save_settings,
):
# Create request with updated settings
request = EmailEventSettingsUpdateRequest(
settings=[
EmailEventSettings(
event=EmailEvent.virtual_key_created, enabled=True
),
EmailEventSettings(
event=EmailEvent.new_user_invitation, enabled=False
),
]
)
# Call the endpoint function directly
response = await update_event_settings(
request=request, user_api_key_dict=mock_user_api_key_auth
)
# Assert response is success
assert (
response["message"] == "Email event settings updated successfully"
)
# Test the reset endpoint
@pytest.mark.asyncio
async def test_reset_event_settings(mock_prisma_client, mock_user_api_key_auth):
"""Test that the reset endpoint correctly restores default email event settings."""
# Mock _save_email_settings to do nothing
async def mock_save_settings(*args, **kwargs):
return None
# Setup mocks
with mock.patch("litellm.proxy.proxy_server.prisma_client", mock_prisma_client):
with mock.patch(
"litellm_enterprise.enterprise_callbacks.send_emails.endpoints._save_email_settings",
side_effect=mock_save_settings,
):
# Call the endpoint function directly
response = await reset_event_settings(
user_api_key_dict=mock_user_api_key_auth
)
# Assert response is success
assert response["message"] == "Email event settings reset to defaults"
# Test handling of prisma client None
@pytest.mark.asyncio
async def test_endpoint_with_no_prisma_client(mock_user_api_key_auth):
"""Test that all endpoints properly handle the case when the database is not connected."""
# Setup mock to return None for prisma_client
with mock.patch("litellm.proxy.proxy_server.prisma_client", None):
# Test get endpoint
with pytest.raises(HTTPException) as exc_info:
await get_email_event_settings(user_api_key_dict=mock_user_api_key_auth)
assert exc_info.value.status_code == 500
assert "Database not connected" in exc_info.value.detail
# Test update endpoint
request = EmailEventSettingsUpdateRequest(settings=[])
with pytest.raises(HTTPException) as exc_info:
await update_event_settings(
request=request, user_api_key_dict=mock_user_api_key_auth
)
assert exc_info.value.status_code == 500
# Test reset endpoint
with pytest.raises(HTTPException) as exc_info:
await reset_event_settings(user_api_key_dict=mock_user_api_key_auth)
assert exc_info.value.status_code == 500

View File

@@ -0,0 +1,119 @@
import os
import sys
import unittest.mock as mock
import pytest
from httpx import Response
sys.path.insert(0, os.path.abspath("../../.."))
from litellm_enterprise.enterprise_callbacks.send_emails.resend_email import (
ResendEmailLogger,
)
@pytest.fixture
def mock_env_vars():
with mock.patch.dict(os.environ, {"RESEND_API_KEY": "test_api_key"}):
yield
@pytest.fixture
def mock_httpx_client():
with mock.patch(
"litellm_enterprise.enterprise_callbacks.send_emails.resend_email.get_async_httpx_client"
) as mock_client:
# Create a mock response
mock_response = mock.AsyncMock(spec=Response)
mock_response.status_code = 200
mock_response.json.return_value = {"id": "test_email_id"}
# Create a mock client
mock_async_client = mock.AsyncMock()
mock_async_client.post.return_value = mock_response
mock_client.return_value = mock_async_client
yield mock_async_client
@pytest.mark.asyncio
async def test_send_email_success(mock_env_vars, mock_httpx_client):
# Initialize the logger
logger = ResendEmailLogger()
# Test data
from_email = "test@example.com"
to_email = ["recipient@example.com"]
subject = "Test Subject"
html_body = "<p>Test email body</p>"
# Send email
await logger.send_email(
from_email=from_email, to_email=to_email, subject=subject, html_body=html_body
)
# Verify the HTTP client was called correctly
mock_httpx_client.post.assert_called_once()
call_args = mock_httpx_client.post.call_args
# Verify the URL
assert call_args[1]["url"] == "https://api.resend.com/emails"
# Verify the request body
request_body = call_args[1]["json"]
assert request_body["from"] == from_email
assert request_body["to"] == to_email
assert request_body["subject"] == subject
assert request_body["html"] == html_body
# Verify the headers
assert call_args[1]["headers"] == {"Authorization": "Bearer test_api_key"}
@pytest.mark.asyncio
async def test_send_email_missing_api_key(mock_httpx_client):
# Remove the API key from environment
if "RESEND_API_KEY" in os.environ:
del os.environ["RESEND_API_KEY"]
# Initialize the logger
logger = ResendEmailLogger()
# Test data
from_email = "test@example.com"
to_email = ["recipient@example.com"]
subject = "Test Subject"
html_body = "<p>Test email body</p>"
# Send email
await logger.send_email(
from_email=from_email, to_email=to_email, subject=subject, html_body=html_body
)
# Verify the HTTP client was called with None as the API key
mock_httpx_client.post.assert_called_once()
call_args = mock_httpx_client.post.call_args
assert call_args[1]["headers"] == {"Authorization": "Bearer None"}
@pytest.mark.asyncio
async def test_send_email_multiple_recipients(mock_env_vars, mock_httpx_client):
# Initialize the logger
logger = ResendEmailLogger()
# Test data with multiple recipients
from_email = "test@example.com"
to_email = ["recipient1@example.com", "recipient2@example.com"]
subject = "Test Subject"
html_body = "<p>Test email body</p>"
# Send email
await logger.send_email(
from_email=from_email, to_email=to_email, subject=subject, html_body=html_body
)
# Verify the HTTP client was called with multiple recipients
mock_httpx_client.post.assert_called_once()
call_args = mock_httpx_client.post.call_args
request_body = call_args[1]["json"]
assert request_body["to"] == to_email

View File

@@ -0,0 +1,216 @@
import unittest.mock as mock
from typing import cast
from unittest.mock import MagicMock, patch
import pytest
from enterprise.litellm_enterprise.enterprise_callbacks.callback_controls import (
EnterpriseCallbackControls,
)
from litellm.constants import X_LITELLM_DISABLE_CALLBACKS
from litellm.integrations.custom_logger import CustomLogger
from litellm.integrations.datadog.datadog import DataDogLogger
from litellm.integrations.langfuse.langfuse_prompt_management import (
LangfusePromptManagement,
)
from litellm.integrations.s3_v2 import S3Logger
from litellm.types.utils import StandardCallbackDynamicParams
class TestEnterpriseCallbackControls:
@pytest.fixture
def mock_premium_user(self):
"""Fixture to mock premium user check as True"""
with patch.object(EnterpriseCallbackControls, '_premium_user_check', return_value=True):
yield
@pytest.fixture
def mock_non_premium_user(self):
"""Fixture to mock premium user check as False"""
with patch.object(EnterpriseCallbackControls, '_premium_user_check', return_value=False):
yield
@pytest.fixture
def mock_request_headers(self):
"""Fixture to mock get_proxy_server_request_headers"""
with patch('enterprise.litellm_enterprise.enterprise_callbacks.callback_controls.get_proxy_server_request_headers') as mock_headers:
yield mock_headers
def test_callback_disabled_langfuse_string(self, mock_premium_user, mock_request_headers):
"""Test that 'langfuse' string callback is disabled when specified in headers"""
mock_request_headers.return_value = {X_LITELLM_DISABLE_CALLBACKS: "langfuse"}
litellm_params = {"proxy_server_request": {"url": "test"}}
standard_callback_dynamic_params = StandardCallbackDynamicParams()
result = EnterpriseCallbackControls.is_callback_disabled_dynamically("langfuse", litellm_params, standard_callback_dynamic_params)
assert result is True
def test_callback_disabled_langfuse_customlogger(self, mock_premium_user, mock_request_headers):
"""Test that LangfusePromptManagement CustomLogger instance is disabled when 'langfuse' specified in headers"""
mock_request_headers.return_value = {X_LITELLM_DISABLE_CALLBACKS: "langfuse"}
litellm_params = {"proxy_server_request": {"url": "test"}}
standard_callback_dynamic_params = StandardCallbackDynamicParams()
langfuse_logger = LangfusePromptManagement()
result = EnterpriseCallbackControls.is_callback_disabled_dynamically(langfuse_logger, litellm_params, standard_callback_dynamic_params)
assert result is True
def test_callback_disabled_s3_v2_string(self, mock_premium_user, mock_request_headers):
"""Test that 's3_v2' string callback is disabled when specified in headers"""
mock_request_headers.return_value = {X_LITELLM_DISABLE_CALLBACKS: "s3_v2"}
litellm_params = {"proxy_server_request": {"url": "test"}}
standard_callback_dynamic_params = StandardCallbackDynamicParams()
result = EnterpriseCallbackControls.is_callback_disabled_dynamically("s3_v2", litellm_params, standard_callback_dynamic_params)
assert result is True
def test_callback_disabled_s3_v2_customlogger(self, mock_premium_user, mock_request_headers):
"""Test that S3Logger CustomLogger instance is disabled when 's3_v2' specified in headers"""
mock_request_headers.return_value = {X_LITELLM_DISABLE_CALLBACKS: "s3_v2"}
litellm_params = {"proxy_server_request": {"url": "test"}}
standard_callback_dynamic_params = StandardCallbackDynamicParams()
# Mock S3Logger to avoid async initialization issues
with patch('litellm.integrations.s3_v2.S3Logger.__init__', return_value=None):
s3_logger = S3Logger()
result = EnterpriseCallbackControls.is_callback_disabled_dynamically(s3_logger, litellm_params, standard_callback_dynamic_params)
assert result is True
def test_callback_disabled_datadog_string(self, mock_premium_user, mock_request_headers):
"""Test that 'datadog' string callback is disabled when specified in headers"""
mock_request_headers.return_value = {X_LITELLM_DISABLE_CALLBACKS: "datadog"}
litellm_params = {"proxy_server_request": {"url": "test"}}
standard_callback_dynamic_params = StandardCallbackDynamicParams()
result = EnterpriseCallbackControls.is_callback_disabled_dynamically("datadog", litellm_params, standard_callback_dynamic_params)
assert result is True
def test_callback_disabled_datadog_customlogger(self, mock_premium_user, mock_request_headers):
"""Test that DataDogLogger CustomLogger instance is disabled when 'datadog' specified in headers"""
mock_request_headers.return_value = {X_LITELLM_DISABLE_CALLBACKS: "datadog"}
litellm_params = {"proxy_server_request": {"url": "test"}}
standard_callback_dynamic_params = StandardCallbackDynamicParams()
# Mock DataDogLogger to avoid async initialization issues
with patch('litellm.integrations.datadog.datadog.DataDogLogger.__init__', return_value=None):
datadog_logger = DataDogLogger()
result = EnterpriseCallbackControls.is_callback_disabled_dynamically(datadog_logger, litellm_params, standard_callback_dynamic_params)
assert result is True
def test_multiple_callbacks_disabled(self, mock_premium_user, mock_request_headers):
"""Test that multiple callbacks can be disabled with comma-separated list"""
mock_request_headers.return_value = {X_LITELLM_DISABLE_CALLBACKS: "langfuse,datadog,s3_v2"}
litellm_params = {"proxy_server_request": {"url": "test"}}
standard_callback_dynamic_params = StandardCallbackDynamicParams()
# Test each callback is disabled
assert EnterpriseCallbackControls.is_callback_disabled_dynamically("langfuse", litellm_params, standard_callback_dynamic_params) is True
assert EnterpriseCallbackControls.is_callback_disabled_dynamically("datadog", litellm_params, standard_callback_dynamic_params) is True
assert EnterpriseCallbackControls.is_callback_disabled_dynamically("s3_v2", litellm_params, standard_callback_dynamic_params) is True
# Test non-disabled callback is not disabled
assert EnterpriseCallbackControls.is_callback_disabled_dynamically("prometheus", litellm_params, standard_callback_dynamic_params) is False
def test_callback_not_disabled_when_not_in_list(self, mock_premium_user, mock_request_headers):
"""Test that callbacks not in the disabled list are not disabled"""
mock_request_headers.return_value = {X_LITELLM_DISABLE_CALLBACKS: "langfuse"}
litellm_params = {"proxy_server_request": {"url": "test"}}
standard_callback_dynamic_params = StandardCallbackDynamicParams()
result = EnterpriseCallbackControls.is_callback_disabled_dynamically("datadog", litellm_params, standard_callback_dynamic_params)
assert result is False
def test_callback_not_disabled_when_no_header(self, mock_premium_user, mock_request_headers):
"""Test that callbacks are not disabled when the header is not present"""
mock_request_headers.return_value = {}
litellm_params = {"proxy_server_request": {"url": "test"}}
standard_callback_dynamic_params = StandardCallbackDynamicParams()
result = EnterpriseCallbackControls.is_callback_disabled_dynamically("langfuse", litellm_params, standard_callback_dynamic_params)
assert result is False
def test_callback_not_disabled_when_header_none(self, mock_premium_user, mock_request_headers):
"""Test that callbacks are not disabled when the header value is None"""
mock_request_headers.return_value = {X_LITELLM_DISABLE_CALLBACKS: None}
litellm_params = {"proxy_server_request": {"url": "test"}}
standard_callback_dynamic_params = StandardCallbackDynamicParams()
result = EnterpriseCallbackControls.is_callback_disabled_dynamically("langfuse", litellm_params, standard_callback_dynamic_params)
assert result is False
def test_non_premium_user_cannot_disable_callbacks(self, mock_non_premium_user, mock_request_headers):
"""Test that non-premium users cannot disable callbacks even with the header"""
mock_request_headers.return_value = {X_LITELLM_DISABLE_CALLBACKS: "langfuse"}
litellm_params = {"proxy_server_request": {"url": "test"}}
standard_callback_dynamic_params = StandardCallbackDynamicParams()
result = EnterpriseCallbackControls.is_callback_disabled_dynamically("langfuse", litellm_params, standard_callback_dynamic_params)
assert result is False
def test_case_insensitive_callback_matching(self, mock_premium_user, mock_request_headers):
"""Test that callback matching is case insensitive"""
mock_request_headers.return_value = {X_LITELLM_DISABLE_CALLBACKS: "LANGFUSE,DataDog"}
litellm_params = {"proxy_server_request": {"url": "test"}}
standard_callback_dynamic_params = StandardCallbackDynamicParams()
# Test lowercase callbacks are disabled
assert EnterpriseCallbackControls.is_callback_disabled_dynamically("langfuse", litellm_params, standard_callback_dynamic_params) is True
assert EnterpriseCallbackControls.is_callback_disabled_dynamically("datadog", litellm_params, standard_callback_dynamic_params) is True
def test_whitespace_handling_in_disabled_callbacks(self, mock_premium_user, mock_request_headers):
"""Test that whitespace around callback names is handled correctly"""
mock_request_headers.return_value = {X_LITELLM_DISABLE_CALLBACKS: " langfuse , datadog , s3_v2 "}
litellm_params = {"proxy_server_request": {"url": "test"}}
standard_callback_dynamic_params = StandardCallbackDynamicParams()
assert EnterpriseCallbackControls.is_callback_disabled_dynamically("langfuse", litellm_params, standard_callback_dynamic_params) is True
assert EnterpriseCallbackControls.is_callback_disabled_dynamically("datadog", litellm_params, standard_callback_dynamic_params) is True
assert EnterpriseCallbackControls.is_callback_disabled_dynamically("s3_v2", litellm_params, standard_callback_dynamic_params) is True
def test_custom_logger_not_in_registry(self, mock_premium_user, mock_request_headers):
"""Test that CustomLogger not in registry is not disabled"""
mock_request_headers.return_value = {X_LITELLM_DISABLE_CALLBACKS: "unknown_logger"}
litellm_params = {"proxy_server_request": {"url": "test"}}
standard_callback_dynamic_params = StandardCallbackDynamicParams()
# Create a mock CustomLogger that's not in the registry
class UnknownLogger(CustomLogger):
pass
unknown_logger = UnknownLogger()
result = EnterpriseCallbackControls.is_callback_disabled_dynamically(unknown_logger, litellm_params, standard_callback_dynamic_params)
assert result is False
def test_exception_handling(self, mock_premium_user, mock_request_headers):
"""Test that exceptions are handled gracefully and return False"""
# Make get_proxy_server_request_headers raise an exception
mock_request_headers.side_effect = Exception("Test exception")
litellm_params = {"proxy_server_request": {"url": "test"}}
standard_callback_dynamic_params = StandardCallbackDynamicParams()
result = EnterpriseCallbackControls.is_callback_disabled_dynamically("langfuse", litellm_params, standard_callback_dynamic_params)
assert result is False
def test_callback_disabled_via_request_body_langfuse(self, mock_premium_user, mock_request_headers):
"""Test that callbacks can be disabled via request body litellm_disabled_callbacks"""
mock_request_headers.return_value = {} # No headers
litellm_params = {"proxy_server_request": {"url": "test"}}
standard_callback_dynamic_params = StandardCallbackDynamicParams(litellm_disabled_callbacks=["langfuse"])
result = EnterpriseCallbackControls.is_callback_disabled_dynamically("langfuse", litellm_params, standard_callback_dynamic_params)
assert result is True
def test_callback_disabled_via_request_body_multiple(self, mock_premium_user, mock_request_headers):
"""Test that multiple callbacks can be disabled via request body"""
mock_request_headers.return_value = {} # No headers
litellm_params = {"proxy_server_request": {"url": "test"}}
standard_callback_dynamic_params = StandardCallbackDynamicParams(litellm_disabled_callbacks=["langfuse", "datadog", "s3_v2"])
# Test each callback is disabled
assert EnterpriseCallbackControls.is_callback_disabled_dynamically("langfuse", litellm_params, standard_callback_dynamic_params) is True
assert EnterpriseCallbackControls.is_callback_disabled_dynamically("datadog", litellm_params, standard_callback_dynamic_params) is True
assert EnterpriseCallbackControls.is_callback_disabled_dynamically("s3_v2", litellm_params, standard_callback_dynamic_params) is True
# Test non-disabled callback is not disabled
assert EnterpriseCallbackControls.is_callback_disabled_dynamically("prometheus", litellm_params, standard_callback_dynamic_params) is False