Added LiteLLM to the stack

commit d220b04e32
parent 0648c1968c
Date: 2025-08-18 09:40:50 +00:00

2682 changed files with 533609 additions and 1 deletions

View File

@@ -0,0 +1,54 @@
# conftest.py
import importlib
import os
import sys
import pytest
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import litellm
@pytest.fixture(scope="function", autouse=True)
def setup_and_teardown():
"""
    This fixture reloads litellm before every test function, to speed up testing by preventing callbacks from chaining across tests.
"""
curr_dir = os.getcwd() # Get the current working directory
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the project directory to the system path
import litellm
from litellm import Router
importlib.reload(litellm)
import asyncio
loop = asyncio.get_event_loop_policy().new_event_loop()
asyncio.set_event_loop(loop)
print(litellm)
# from litellm import Router, completion, aembedding, acompletion, embedding
yield
# Teardown code (executes after the yield point)
loop.close() # Close the loop created earlier
asyncio.set_event_loop(None) # Remove the reference to the loop
def pytest_collection_modifyitems(config, items):
# Separate tests in 'test_amazing_proxy_custom_logger.py' and other tests
custom_logger_tests = [
item for item in items if "custom_logger" in item.parent.name
]
other_tests = [item for item in items if "custom_logger" not in item.parent.name]
# Sort tests based on their names
custom_logger_tests.sort(key=lambda x: x.name)
other_tests.sort(key=lambda x: x.name)
# Reorder the items list
items[:] = custom_logger_tests + other_tests
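# A minimal sketch (hypothetical, not part of this commit) of the callback
# chaining the fixture above guards against: litellm keeps module-level
# callback lists, so without a reload they accumulate across test functions.
#
#   import litellm
#   litellm.success_callback.append("langfuse")  # registered by test A
#   litellm.success_callback.append("langfuse")  # registered again by test B
#   # len(litellm.success_callback) == 2 -> callbacks chain across tests
#
# importlib.reload(litellm) restores the module's clean initial state.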

View File

@@ -0,0 +1,24 @@
============================= test session starts ==============================
platform darwin -- Python 3.13.1, pytest-8.3.5, pluggy-1.5.0 -- /Users/krrishdholakia/Documents/litellm/myenv/bin/python3.13
cachedir: .pytest_cache
rootdir: /Users/krrishdholakia/Documents/litellm
configfile: pyproject.toml
plugins: respx-0.22.0, postgresql-7.0.1, anyio-4.4.0, asyncio-0.26.0, mock-3.14.0, ddtrace-2.19.0rc1
asyncio: mode=Mode.STRICT, asyncio_default_fixture_loop_scope=None, asyncio_default_test_loop_scope=function
collecting ... collected 3 items
test_supports_tool_choice.py::test_check_provider_match PASSED [ 33%]
test_supports_tool_choice.py::test_supports_tool_choice PASSED [ 66%]
test_supports_tool_choice.py::test_supports_tool_choice_simple_tests PASSED [100%]
=============================== warnings summary ===============================
../../myenv/lib/python3.13/site-packages/pydantic/_internal/_config.py:295
/Users/krrishdholakia/Documents/litellm/myenv/lib/python3.13/site-packages/pydantic/_internal/_config.py:295: PydanticDeprecatedSince20: Support for class-based `config` is deprecated, use ConfigDict instead. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.10/migration/
warnings.warn(DEPRECATION_MESSAGE, DeprecationWarning)
../../litellm/caching/llm_caching_handler.py:17
/Users/krrishdholakia/Documents/litellm/litellm/caching/llm_caching_handler.py:17: DeprecationWarning: There is no current event loop
event_loop = asyncio.get_event_loop()
-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
======================== 3 passed, 2 warnings in 0.92s =========================

View File

@@ -0,0 +1,124 @@
import asyncio
import copy
import sys
import time
from datetime import datetime
from unittest import mock
from dotenv import load_dotenv
from litellm.types.utils import StandardCallbackDynamicParams
load_dotenv()
import os
sys.path.insert(
0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import pytest
import litellm
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
@pytest.mark.asyncio
async def test_client_session_helper():
"""Test that the client session helper handles event loop changes correctly"""
try:
# Create a transport with the new helper
transport = AsyncHTTPHandler._create_aiohttp_transport()
if transport is not None:
print('✅ Successfully created aiohttp transport with helper')
# Test the helper function directly if it's a LiteLLMAiohttpTransport
if hasattr(transport, '_get_valid_client_session'):
session1 = transport._get_valid_client_session() # type: ignore
print(f'✅ First session created: {type(session1).__name__}')
# Call it again to test reuse
session2 = transport._get_valid_client_session() # type: ignore
print(f'✅ Second session call: {type(session2).__name__}')
# In the same event loop, should be the same session
print(f'✅ Same session reused: {session1 is session2}')
return True
else:
print(' No aiohttp transport available (probably missing httpx-aiohttp)')
return True
except Exception as e:
print(f'❌ Error: {e}')
import traceback
traceback.print_exc()
return False
async def test_event_loop_robustness():
"""Test behavior when event loops change (simulating CI/CD scenario)"""
try:
# Test session creation in multiple scenarios
transport = AsyncHTTPHandler._create_aiohttp_transport()
if transport and hasattr(transport, '_get_valid_client_session'):
# Test 1: Normal usage
session = transport._get_valid_client_session() # type: ignore
print(f'✅ Normal session creation works: {session is not None}')
# Test 2: Force recreation by setting client to a callable
from aiohttp import ClientSession
transport.client = lambda: ClientSession() # type: ignore
session2 = transport._get_valid_client_session() # type: ignore
print(f'✅ Session recreation after callable works: {session2 is not None}')
return True
else:
print(' Transport not available or no helper method')
return True
except Exception as e:
print(f'❌ Error in event loop robustness test: {e}')
import traceback
traceback.print_exc()
return False
async def test_httpx_request_simulation():
"""Test that the transport can handle a simulated HTTP request"""
try:
transport = AsyncHTTPHandler._create_aiohttp_transport()
if transport is not None:
print('✅ Transport created for request simulation')
# Create a simple httpx request to test with
import httpx
request = httpx.Request('GET', 'https://httpbin.org/headers')
# Just test that we can get a valid session for this request context
if hasattr(transport, '_get_valid_client_session'):
session = transport._get_valid_client_session() # type: ignore
print(f'✅ Got valid session for request: {session is not None}')
# Test that session has required aiohttp methods
has_request_method = hasattr(session, 'request')
print(f'✅ Session has request method: {has_request_method}')
return has_request_method
return True
else:
print(' No transport available for request simulation')
return True
except Exception as e:
print(f'❌ Error in request simulation: {e}')
return False
if __name__ == "__main__":
print("Testing client session helper and event loop handling fix...")
result1 = asyncio.run(test_client_session_helper())
result2 = asyncio.run(test_event_loop_robustness())
result3 = asyncio.run(test_httpx_request_simulation())
if result1 and result2 and result3:
print("🎉 All tests passed! The helper function approach should fix the CI/CD event loop issues.")
else:
print("💥 Some tests failed")

View File

@@ -0,0 +1,195 @@
# What is this? Unit tests for AWSSecretsManagerV2 (write/read/delete secrets and primary-secret reads).
import asyncio
import os
import sys
import traceback
from dotenv import load_dotenv
import litellm.types
import litellm.types.utils
load_dotenv()
import io
import sys
import os
# Ensure the project root is in the Python path
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../../..")))
print("Python Path:", sys.path)
print("Current Working Directory:", os.getcwd())
from typing import Optional
from unittest.mock import MagicMock, patch
import pytest
import uuid
import json
from litellm.secret_managers.aws_secret_manager_v2 import AWSSecretsManagerV2
def check_aws_credentials():
"""Helper function to check if AWS credentials are set"""
required_vars = ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_REGION_NAME"]
missing_vars = [var for var in required_vars if not os.getenv(var)]
if missing_vars:
pytest.skip(f"Missing required AWS credentials: {', '.join(missing_vars)}")
@pytest.mark.asyncio
async def test_write_and_read_simple_secret():
"""Test writing and reading a simple string secret"""
check_aws_credentials()
secret_manager = AWSSecretsManagerV2()
test_secret_name = f"litellm_test_{uuid.uuid4().hex[:8]}"
test_secret_value = "test_value_123"
try:
# Write secret
write_response = await secret_manager.async_write_secret(
secret_name=test_secret_name,
secret_value=test_secret_value,
description="LiteLLM Test Secret",
)
print("Write Response:", write_response)
assert write_response is not None
assert "ARN" in write_response
assert "Name" in write_response
assert write_response["Name"] == test_secret_name
# Read secret back
read_value = await secret_manager.async_read_secret(
secret_name=test_secret_name
)
print("Read Value:", read_value)
assert read_value == test_secret_value
finally:
# Cleanup: Delete the secret
delete_response = await secret_manager.async_delete_secret(
secret_name=test_secret_name
)
print("Delete Response:", delete_response)
assert delete_response is not None
@pytest.mark.asyncio
async def test_write_and_read_json_secret():
"""Test writing and reading a JSON structured secret"""
check_aws_credentials()
secret_manager = AWSSecretsManagerV2()
test_secret_name = f"litellm_test_{uuid.uuid4().hex[:8]}_json"
test_secret_value = {
"api_key": "test_key",
"model": "gpt-4",
"temperature": 0.7,
"metadata": {"team": "ml", "project": "litellm"},
}
try:
# Write JSON secret
write_response = await secret_manager.async_write_secret(
secret_name=test_secret_name,
secret_value=json.dumps(test_secret_value),
description="LiteLLM JSON Test Secret",
)
print("Write Response:", write_response)
# Read and parse JSON secret
read_value = await secret_manager.async_read_secret(
secret_name=test_secret_name
)
parsed_value = json.loads(read_value)
print("Read Value:", read_value)
assert parsed_value == test_secret_value
assert parsed_value["api_key"] == "test_key"
assert parsed_value["metadata"]["team"] == "ml"
finally:
# Cleanup: Delete the secret
delete_response = await secret_manager.async_delete_secret(
secret_name=test_secret_name
)
print("Delete Response:", delete_response)
assert delete_response is not None
@pytest.mark.asyncio
async def test_read_nonexistent_secret():
"""Test reading a secret that doesn't exist"""
check_aws_credentials()
secret_manager = AWSSecretsManagerV2()
nonexistent_secret = f"litellm_nonexistent_{uuid.uuid4().hex}"
response = await secret_manager.async_read_secret(secret_name=nonexistent_secret)
assert response is None
@pytest.mark.asyncio
async def test_primary_secret_functionality():
"""Test storing and retrieving secrets from a primary secret"""
check_aws_credentials()
secret_manager = AWSSecretsManagerV2()
primary_secret_name = f"litellm_test_primary_{uuid.uuid4().hex[:8]}"
# Create a primary secret with multiple key-value pairs
primary_secret_value = {
"api_key_1": "secret_value_1",
"api_key_2": "secret_value_2",
"database_url": "postgresql://user:password@localhost:5432/db",
"nested_secret": json.dumps({"key": "value", "number": 42}),
}
try:
# Write the primary secret
write_response = await secret_manager.async_write_secret(
secret_name=primary_secret_name,
secret_value=json.dumps(primary_secret_value),
description="LiteLLM Test Primary Secret",
)
print("Primary Secret Write Response:", write_response)
assert write_response is not None
assert "ARN" in write_response
assert "Name" in write_response
assert write_response["Name"] == primary_secret_name
# Test reading individual secrets from the primary secret
for key, expected_value in primary_secret_value.items():
# Read using the primary_secret_name parameter
value = await secret_manager.async_read_secret(
secret_name=key, primary_secret_name=primary_secret_name
)
print(f"Read {key} from primary secret:", value)
assert value == expected_value
# Test reading a non-existent key from the primary secret
non_existent_key = "non_existent_key"
value = await secret_manager.async_read_secret(
secret_name=non_existent_key, primary_secret_name=primary_secret_name
)
assert value is None, f"Expected None for non-existent key, got {value}"
finally:
# Cleanup: Delete the primary secret
delete_response = await secret_manager.async_delete_secret(
secret_name=primary_secret_name
)
print("Delete Response:", delete_response)
assert delete_response is not None
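# The AWSSecretsManagerV2 surface exercised above (names as used in the tests):
#   async_write_secret(secret_name, secret_value, description=...)
#   async_read_secret(secret_name, primary_secret_name=None)
#   async_delete_secret(secret_name)
# primary_secret_name lets several logical secrets live as keys inside one
# JSON-valued "primary" secret, read back individually by key.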

View File

@@ -0,0 +1,30 @@
import json
import os
import sys
from datetime import datetime
from unittest.mock import AsyncMock, Mock, patch
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import pytest
import litellm
from litellm.proxy._types import KeyManagementSystem
from litellm.secret_managers.main import get_secret
class MockSecretClient:
def get_secret(self, secret_name):
return Mock(value="mocked_secret_value")
@pytest.mark.asyncio
async def test_azure_kms():
"""
    Basic check that get_secret returns the value from Azure Key Vault when the key management system is set to Azure Key Vault
"""
with patch("litellm.secret_manager_client", new=MockSecretClient()):
litellm._key_management_system = KeyManagementSystem.AZURE_KEY_VAULT
secret = get_secret(secret_name="ishaan-test-key")
assert secret == "mocked_secret_value"
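# A sketch of the dispatch this test relies on (inferred from the patch target
# and mock above, not a definitive description of get_secret's internals):
# when litellm._key_management_system is KeyManagementSystem.AZURE_KEY_VAULT,
# get_secret() calls litellm.secret_manager_client.get_secret(secret_name)
# and returns the `.value` attribute of the result.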

View File

@@ -0,0 +1,213 @@
import os
import sys
import pytest
from dotenv import load_dotenv
load_dotenv()
import os
import httpx
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
from unittest.mock import patch, MagicMock
import logging
from litellm._logging import verbose_logger
import uuid
verbose_logger.setLevel(logging.DEBUG)
from litellm.secret_managers.hashicorp_secret_manager import HashicorpSecretManager
hashicorp_secret_manager = HashicorpSecretManager()
mock_vault_response = {
"request_id": "80fafb6a-e96a-4c5b-29fa-ff505ac72201",
"lease_id": "",
"renewable": False,
"lease_duration": 0,
"data": {
"data": {"key": "value-mock"},
"metadata": {
"created_time": "2025-01-01T22:13:50.93942388Z",
"custom_metadata": None,
"deletion_time": "",
"destroyed": False,
"version": 1,
},
},
"wrap_info": None,
"warnings": None,
"auth": None,
"mount_type": "kv",
}
# Mock response for write operations (its shape differs from the read response above)
mock_write_response = {
"request_id": "80fafb6a-e96a-4c5b-29fa-ff505ac72201",
"lease_id": "",
"renewable": False,
"lease_duration": 0,
"data": {
"created_time": "2025-01-04T16:58:42.684673531Z",
"custom_metadata": None,
"deletion_time": "",
"destroyed": False,
"version": 1,
},
"wrap_info": None,
"warnings": None,
"auth": None,
"mount_type": "kv",
}
def test_hashicorp_secret_manager_get_secret():
with patch("litellm.llms.custom_httpx.http_handler.HTTPHandler.get") as mock_get:
# Configure the mock response using MagicMock
mock_response = MagicMock()
mock_response.json.return_value = mock_vault_response
mock_response.raise_for_status.return_value = None
mock_get.return_value = mock_response
# Test the secret manager
secret = hashicorp_secret_manager.sync_read_secret("sample-secret-mock")
assert secret == "value-mock"
# Verify the request was made with correct parameters
mock_get.assert_called_once()
called_url = mock_get.call_args[0][0]
assert "sample-secret-mock" in called_url
assert (
called_url
== "https://test-cluster-public-vault-0f98180c.e98296b2.z1.hashicorp.cloud:8200/v1/admin/secret/data/sample-secret-mock"
)
assert "X-Vault-Token" in mock_get.call_args.kwargs["headers"]
@pytest.mark.asyncio
async def test_hashicorp_secret_manager_write_secret():
with patch(
"litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post"
) as mock_post:
# Configure the mock response
mock_response = MagicMock()
mock_response.json.return_value = (
mock_write_response # Use the write-specific response
)
mock_response.raise_for_status.return_value = None
mock_post.return_value = mock_response
# Test the secret manager
secret_name = f"sample-secret-test-{uuid.uuid4()}"
secret_value = f"value-mock-{uuid.uuid4()}"
response = await hashicorp_secret_manager.async_write_secret(
secret_name=secret_name,
secret_value=secret_value,
)
# Verify the response and that the request was made correctly
assert (
response == mock_write_response
) # Compare against write-specific response
mock_post.assert_called_once()
print("CALL ARGS=", mock_post.call_args)
print("call args[1]=", mock_post.call_args[1])
# Verify URL
called_url = mock_post.call_args[1]["url"]
assert secret_name in called_url
assert (
called_url
== f"{hashicorp_secret_manager.vault_addr}/v1/admin/secret/data/{secret_name}"
)
# Verify request body
json_data = mock_post.call_args[1]["json"]
assert "data" in json_data
assert "key" in json_data["data"]
assert json_data["data"]["key"] == secret_value
@pytest.mark.asyncio
async def test_hashicorp_secret_manager_delete_secret():
with patch(
"litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.delete"
) as mock_delete:
# Configure the mock response
mock_response = MagicMock()
mock_response.raise_for_status.return_value = None
mock_delete.return_value = mock_response
# Test the secret manager
secret_name = f"sample-secret-test-{uuid.uuid4()}"
response = await hashicorp_secret_manager.async_delete_secret(
secret_name=secret_name
)
# Verify the response
assert response == {
"status": "success",
"message": f"Secret {secret_name} deleted successfully",
}
# Verify the request was made correctly
mock_delete.assert_called_once()
# Verify URL
called_url = mock_delete.call_args[1]["url"]
assert secret_name in called_url
assert (
called_url
== f"{hashicorp_secret_manager.vault_addr}/v1/admin/secret/data/{secret_name}"
)
def test_hashicorp_secret_manager_tls_cert_auth(monkeypatch):
monkeypatch.setenv("HCP_VAULT_TOKEN", "test-client-token-12345")
print("HCP_VAULT_TOKEN=", os.getenv("HCP_VAULT_TOKEN"))
# Mock both httpx.post and httpx.Client
with patch("httpx.Client") as mock_client:
# Configure the mock client and response
mock_response = MagicMock()
mock_response.json.return_value = {
"auth": {
"client_token": "test-client-token-12345",
"lease_duration": 3600,
"renewable": True,
}
}
mock_response.raise_for_status.return_value = None
# Configure the mock client's post method
mock_client_instance = MagicMock()
mock_client_instance.post.return_value = mock_response
mock_client.return_value = mock_client_instance
# Create a new instance with TLS cert config
test_manager = HashicorpSecretManager()
test_manager.tls_cert_path = "cert.pem"
test_manager.tls_key_path = "key.pem"
test_manager.vault_cert_role = "test-role"
test_manager.vault_namespace = "test-namespace"
# Test the TLS auth method
token = test_manager._auth_via_tls_cert()
# Verify the token
assert token == "test-client-token-12345"
# Verify Client was created with correct cert tuple
mock_client.assert_called_once_with(cert=("cert.pem", "key.pem"))
# Verify post was called with correct parameters
mock_client_instance.post.assert_called_once_with(
f"{test_manager.vault_addr}/v1/auth/cert/login",
headers={"X-Vault-Namespace": "test-namespace"},
json={"name": "test-role"},
)
# Verify the token was cached
assert test_manager.cache.get_cache("hcp_vault_token") == "test-client-token-12345"
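# For reference, the Vault KV v2 conventions these mocks assume (derived from
# the URLs and payloads asserted above, not from the Vault docs):
#   read:   GET    {vault_addr}/v1/admin/secret/data/<name> -> value under resp["data"]["data"]
#   write:  POST   {vault_addr}/v1/admin/secret/data/<name> with body {"data": {"key": <value>}}
#   delete: DELETE {vault_addr}/v1/admin/secret/data/<name>
# with the vault token sent in the "X-Vault-Token" header.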

View File

@@ -0,0 +1,116 @@
import json
import os
import sys
import time
from datetime import datetime
from unittest.mock import AsyncMock, patch, MagicMock
import pytest
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import litellm
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model",
[
"bedrock/mistral.mistral-7b-instruct-v0:2",
"openai/gpt-4o",
"openai/self_hosted",
"bedrock/anthropic.claude-3-5-haiku-20241022-v1:0",
],
)
async def test_litellm_overhead(model):
litellm._turn_on_debug()
start_time = datetime.now()
if model == "openai/self_hosted":
response = await litellm.acompletion(
model=model,
messages=[{"role": "user", "content": "Hello, world!"}],
api_base="https://exampleopenaiendpoint-production.up.railway.app/",
)
else:
response = await litellm.acompletion(
model=model,
messages=[{"role": "user", "content": "Hello, world!"}],
)
end_time = datetime.now()
total_time_ms = (end_time - start_time).total_seconds() * 1000
print(response)
print(response._hidden_params)
litellm_overhead_ms = response._hidden_params["litellm_overhead_time_ms"]
# calculate percent of overhead caused by litellm
overhead_percent = litellm_overhead_ms * 100 / total_time_ms
print("##########################\n")
print("total_time_ms", total_time_ms)
print("response litellm_overhead_ms", litellm_overhead_ms)
print("litellm overhead_percent {}%".format(overhead_percent))
print("##########################\n")
assert litellm_overhead_ms > 0
assert litellm_overhead_ms < 1000
# latency overhead should be less than total request time
assert litellm_overhead_ms < (end_time - start_time).total_seconds() * 1000
# latency overhead should be under 40% of total request time
assert overhead_percent < 40
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model",
[
"bedrock/mistral.mistral-7b-instruct-v0:2",
"openai/gpt-4o",
"bedrock/anthropic.claude-3-5-haiku-20241022-v1:0",
"openai/self_hosted",
],
)
async def test_litellm_overhead_stream(model):
litellm._turn_on_debug()
start_time = datetime.now()
if model == "openai/self_hosted":
response = await litellm.acompletion(
model=model,
messages=[{"role": "user", "content": "Hello, world!"}],
api_base="https://exampleopenaiendpoint-production.up.railway.app/",
stream=True,
)
else:
response = await litellm.acompletion(
model=model,
messages=[{"role": "user", "content": "Hello, world!"}],
stream=True,
)
async for chunk in response:
print()
end_time = datetime.now()
total_time_ms = (end_time - start_time).total_seconds() * 1000
print(response)
print(response._hidden_params)
litellm_overhead_ms = response._hidden_params["litellm_overhead_time_ms"]
# calculate percent of overhead caused by litellm
overhead_percent = litellm_overhead_ms * 100 / total_time_ms
print("##########################\n")
print("total_time_ms", total_time_ms)
print("response litellm_overhead_ms", litellm_overhead_ms)
print("litellm overhead_percent {}%".format(overhead_percent))
print("##########################\n")
assert litellm_overhead_ms > 0
assert litellm_overhead_ms < 1000
# latency overhead should be less than total request time
assert litellm_overhead_ms < (end_time - start_time).total_seconds() * 1000
# latency overhead should be under 40% of total request time
assert overhead_percent < 40

View File

@@ -0,0 +1,279 @@
import json
import os
import sys
import time
from datetime import datetime
from unittest.mock import AsyncMock, patch, MagicMock
import pytest
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import litellm
from litellm.integrations.custom_logger import CustomLogger
from litellm.litellm_core_utils.logging_callback_manager import LoggingCallbackManager
from litellm.integrations.langfuse.langfuse_prompt_management import (
LangfusePromptManagement,
)
from litellm.integrations.opentelemetry import OpenTelemetry
# Test fixtures
@pytest.fixture
def callback_manager():
manager = LoggingCallbackManager()
# Reset callbacks before each test
manager._reset_all_callbacks()
return manager
@pytest.fixture
def mock_custom_logger():
class TestLogger(CustomLogger):
def log_success_event(self, kwargs, response_obj, start_time, end_time):
pass
return TestLogger()
# Test cases
def test_add_string_callback():
"""
Test adding a string callback to litellm.callbacks - only 1 instance of the string callback should be added
"""
manager = LoggingCallbackManager()
test_callback = "test_callback"
# Add string callback
manager.add_litellm_callback(test_callback)
assert test_callback in litellm.callbacks
# Test duplicate prevention
manager.add_litellm_callback(test_callback)
assert litellm.callbacks.count(test_callback) == 1
def test_duplicate_langfuse_logger_test():
manager = LoggingCallbackManager()
for _ in range(10):
langfuse_logger = LangfusePromptManagement()
manager.add_litellm_success_callback(langfuse_logger)
print("litellm.success_callback: ", litellm.success_callback)
assert len(litellm.success_callback) == 1
def test_duplicate_multiple_loggers_test():
manager = LoggingCallbackManager()
for _ in range(10):
langfuse_logger = LangfusePromptManagement()
otel_logger = OpenTelemetry()
manager.add_litellm_success_callback(langfuse_logger)
manager.add_litellm_success_callback(otel_logger)
print("litellm.success_callback: ", litellm.success_callback)
assert len(litellm.success_callback) == 2
# Check exactly one instance of each logger type
langfuse_count = sum(
1
for callback in litellm.success_callback
if isinstance(callback, LangfusePromptManagement)
)
otel_count = sum(
1
for callback in litellm.success_callback
if isinstance(callback, OpenTelemetry)
)
assert (
langfuse_count == 1
), "Should have exactly one LangfusePromptManagement instance"
assert otel_count == 1, "Should have exactly one OpenTelemetry instance"
def test_add_function_callback():
manager = LoggingCallbackManager()
def test_func(kwargs):
pass
# Add function callback
manager.add_litellm_callback(test_func)
assert test_func in litellm.callbacks
# Test duplicate prevention
manager.add_litellm_callback(test_func)
assert litellm.callbacks.count(test_func) == 1
def test_add_custom_logger(mock_custom_logger):
manager = LoggingCallbackManager()
# Add custom logger
manager.add_litellm_callback(mock_custom_logger)
assert mock_custom_logger in litellm.callbacks
def test_add_multiple_callback_types(mock_custom_logger):
manager = LoggingCallbackManager()
def test_func(kwargs):
pass
string_callback = "test_callback"
# Add different types of callbacks
manager.add_litellm_callback(string_callback)
manager.add_litellm_callback(test_func)
manager.add_litellm_callback(mock_custom_logger)
assert string_callback in litellm.callbacks
assert test_func in litellm.callbacks
assert mock_custom_logger in litellm.callbacks
assert len(litellm.callbacks) == 3
def test_success_failure_callbacks():
manager = LoggingCallbackManager()
success_callback = "success_callback"
failure_callback = "failure_callback"
# Add callbacks
manager.add_litellm_success_callback(success_callback)
manager.add_litellm_failure_callback(failure_callback)
assert success_callback in litellm.success_callback
assert failure_callback in litellm.failure_callback
def test_async_callbacks():
manager = LoggingCallbackManager()
async_success = "async_success"
async_failure = "async_failure"
# Add async callbacks
manager.add_litellm_async_success_callback(async_success)
manager.add_litellm_async_failure_callback(async_failure)
assert async_success in litellm._async_success_callback
assert async_failure in litellm._async_failure_callback
def test_remove_callback_from_list_by_object():
manager = LoggingCallbackManager()
# Reset all callbacks
manager._reset_all_callbacks()
    class TestObject:
def __init__(self):
manager.add_litellm_callback(self.callback)
manager.add_litellm_success_callback(self.callback)
manager.add_litellm_failure_callback(self.callback)
manager.add_litellm_async_success_callback(self.callback)
manager.add_litellm_async_failure_callback(self.callback)
def callback(self):
pass
obj = TestObject()
manager.remove_callback_from_list_by_object(litellm.callbacks, obj)
manager.remove_callback_from_list_by_object(litellm.success_callback, obj)
manager.remove_callback_from_list_by_object(litellm.failure_callback, obj)
manager.remove_callback_from_list_by_object(litellm._async_success_callback, obj)
manager.remove_callback_from_list_by_object(litellm._async_failure_callback, obj)
# Verify all callback lists are empty
assert len(litellm.callbacks) == 0
assert len(litellm.success_callback) == 0
assert len(litellm.failure_callback) == 0
assert len(litellm._async_success_callback) == 0
assert len(litellm._async_failure_callback) == 0
def test_reset_callbacks(callback_manager):
# Add various callbacks
callback_manager.add_litellm_callback("test")
callback_manager.add_litellm_success_callback("success")
callback_manager.add_litellm_failure_callback("failure")
callback_manager.add_litellm_async_success_callback("async_success")
callback_manager.add_litellm_async_failure_callback("async_failure")
# Reset all callbacks
callback_manager._reset_all_callbacks()
# Verify all callback lists are empty
assert len(litellm.callbacks) == 0
assert len(litellm.success_callback) == 0
assert len(litellm.failure_callback) == 0
assert len(litellm._async_success_callback) == 0
assert len(litellm._async_failure_callback) == 0
@pytest.mark.asyncio
async def test_slack_alerting_callback_registration(callback_manager):
"""
Test that litellm callbacks are correctly registered for slack alerting
when outage_alerts or region_outage_alerts are enabled
"""
from litellm.caching.caching import DualCache
from litellm.proxy.utils import ProxyLogging
from litellm.integrations.SlackAlerting.slack_alerting import SlackAlerting
from unittest.mock import AsyncMock, patch
# Mock the async HTTP handler
with patch('litellm.integrations.SlackAlerting.slack_alerting.get_async_httpx_client') as mock_http:
mock_http.return_value = AsyncMock()
# Create a fresh ProxyLogging instance
proxy_logging = ProxyLogging(user_api_key_cache=DualCache())
# Test 1: No callbacks should be added when alerting is None
proxy_logging.update_values(
alerting=None,
alert_types=["outage_alerts", "region_outage_alerts"]
)
assert len(litellm.callbacks) == 0
# Test 2: Callbacks should be added when slack alerting is enabled with outage alerts
proxy_logging.update_values(
alerting=["slack"],
alert_types=["outage_alerts"]
)
assert len(litellm.callbacks) == 1
assert isinstance(litellm.callbacks[0], SlackAlerting)
# Test 3: Callbacks should be added when slack alerting is enabled with region outage alerts
callback_manager._reset_all_callbacks() # Reset callbacks
proxy_logging.update_values(
alerting=["slack"],
alert_types=["region_outage_alerts"]
)
assert len(litellm.callbacks) == 1
assert isinstance(litellm.callbacks[0], SlackAlerting)
# Test 4: No callbacks should be added for other alert types
callback_manager._reset_all_callbacks() # Reset callbacks
proxy_logging.update_values(
alerting=["slack"],
alert_types=["budget_alerts"] # Some other alert type
)
assert len(litellm.callbacks) == 0
# Test 5: Both success and regular callbacks should be added
callback_manager._reset_all_callbacks() # Reset callbacks
proxy_logging.update_values(
alerting=["slack"],
alert_types=["outage_alerts"]
)
assert len(litellm.callbacks) == 1 # Regular callback for outage alerts
assert len(litellm.success_callback) == 1 # Success callback for response_taking_too_long
assert isinstance(litellm.callbacks[0], SlackAlerting)
# Get the method reference for comparison
response_taking_too_long_callback = proxy_logging.slack_alerting_instance.response_taking_too_long_callback
assert litellm.success_callback[0] == response_taking_too_long_callback
# Cleanup
callback_manager._reset_all_callbacks()

File diff suppressed because it is too large

View File

@@ -0,0 +1,360 @@
import os
import sys
import time
import traceback
import uuid
from dotenv import load_dotenv
import json
load_dotenv()
import os
import tempfile
from uuid import uuid4
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import pytest
import litellm
from litellm.llms.azure.azure import get_azure_ad_token_from_oidc
from litellm.llms.bedrock.chat import BedrockConverseLLM, BedrockLLM
from litellm.secret_managers.aws_secret_manager_v2 import AWSSecretsManagerV2
from litellm.secret_managers.main import (
get_secret,
_should_read_secret_from_secret_manager,
)
def load_vertex_ai_credentials():
# Define the path to the vertex_key.json file
print("loading vertex ai credentials")
filepath = os.path.dirname(os.path.abspath(__file__))
vertex_key_path = filepath + "/vertex_key.json"
# Read the existing content of the file or create an empty dictionary
try:
with open(vertex_key_path, "r") as file:
# Read the file content
print("Read vertexai file path")
content = file.read()
# If the file is empty or not valid JSON, create an empty dictionary
if not content or not content.strip():
service_account_key_data = {}
else:
# Attempt to load the existing JSON content
file.seek(0)
service_account_key_data = json.load(file)
except FileNotFoundError:
# If the file doesn't exist, create an empty dictionary
service_account_key_data = {}
# Update the service_account_key_data with environment variables
private_key_id = os.environ.get("VERTEX_AI_PRIVATE_KEY_ID", "")
private_key = os.environ.get("VERTEX_AI_PRIVATE_KEY", "")
private_key = private_key.replace("\\n", "\n")
service_account_key_data["private_key_id"] = private_key_id
service_account_key_data["private_key"] = private_key
# Create a temporary file
with tempfile.NamedTemporaryFile(mode="w+", delete=False) as temp_file:
# Write the updated content to the temporary files
json.dump(service_account_key_data, temp_file, indent=2)
# Export the temporary file as GOOGLE_APPLICATION_CREDENTIALS
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = os.path.abspath(temp_file.name)
def test_aws_secret_manager():
import json
AWSSecretsManagerV2.load_aws_secret_manager(use_aws_secret_manager=True)
secret_val = get_secret("litellm_master_key")
print(f"secret_val: {secret_val}")
# cast json to dict
secret_val = json.loads(secret_val)
assert secret_val["litellm_master_key"] == "sk-1234"
def redact_oidc_signature(secret_val):
    # drop the last `.`-separated segment (the signature) and replace it with "SIGNATURE_REMOVED"
    return ".".join(secret_val.split(".")[:-1] + ["SIGNATURE_REMOVED"])
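# Example with a hypothetical three-segment OIDC token:
#   redact_oidc_signature("header.payload.signature") -> "header.payload.SIGNATURE_REMOVED"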
@pytest.mark.skipif(
os.environ.get("K_SERVICE") is None,
reason="Cannot run without being in GCP Cloud Run",
)
def test_oidc_google():
secret_val = get_secret(
"oidc/google/https://bedrock-runtime.us-east-1.amazonaws.com/model/amazon.titan-text-express-v1/invoke"
)
print(f"secret_val: {redact_oidc_signature(secret_val)}")
@pytest.mark.skipif(
os.environ.get("ACTIONS_ID_TOKEN_REQUEST_TOKEN") is None,
reason="Cannot run without being in GitHub Actions",
)
def test_oidc_github():
secret_val = get_secret(
"oidc/github/https://bedrock-runtime.us-east-1.amazonaws.com/model/amazon.titan-text-express-v1/invoke"
)
print(f"secret_val: {redact_oidc_signature(secret_val)}")
@pytest.mark.skipif(
os.environ.get("CIRCLE_OIDC_TOKEN") is None,
reason="Cannot run without being in CircleCI Runner",
)
def test_oidc_circleci():
secret_val = get_secret("oidc/circleci/")
print(f"secret_val: {redact_oidc_signature(secret_val)}")
@pytest.mark.skipif(
os.environ.get("CIRCLE_OIDC_TOKEN_V2") is None,
reason="Cannot run without being in CircleCI Runner",
)
def test_oidc_circleci_v2():
secret_val = get_secret(
"oidc/circleci_v2/https://bedrock-runtime.us-east-1.amazonaws.com/model/amazon.titan-text-express-v1/invoke"
)
print(f"secret_val: {redact_oidc_signature(secret_val)}")
@pytest.mark.skipif(
os.environ.get("CIRCLE_OIDC_TOKEN") is None,
reason="Cannot run without being in CircleCI Runner",
)
def test_oidc_circleci_with_azure():
# TODO: Switch to our own Azure account, currently using ai.moda's account
os.environ["AZURE_TENANT_ID"] = "17c0a27a-1246-4aa1-a3b6-d294e80e783c"
os.environ["AZURE_CLIENT_ID"] = "4faf5422-b2bd-45e8-a6d7-46543a38acd0"
azure_ad_token = get_azure_ad_token_from_oidc(
azure_ad_token="oidc/circleci/",
azure_client_id=None,
azure_tenant_id=None,
)
print(f"secret_val: {redact_oidc_signature(azure_ad_token)}")
@pytest.mark.skipif(
os.environ.get("CIRCLE_OIDC_TOKEN") is None,
reason="Cannot run without being in CircleCI Runner",
)
def test_oidc_circle_v1_with_amazon():
# The purpose of this test is to get logs using the older v1 of the CircleCI OIDC token
# TODO: This is using ai.moda's IAM role, we should use LiteLLM's IAM role eventually
aws_role_name = "arn:aws:iam::335785316107:role/litellm-github-unit-tests-circleci-v1-assume-only"
aws_web_identity_token = "oidc/circleci/"
bllm = BedrockLLM()
creds = bllm.get_credentials(
aws_region_name="ca-west-1",
aws_web_identity_token=aws_web_identity_token,
aws_role_name=aws_role_name,
aws_session_name="assume-v1-session",
)
@pytest.mark.skipif(
os.environ.get("CIRCLE_OIDC_TOKEN") is None,
reason="Cannot run without being in CircleCI Runner",
)
def test_oidc_circle_v1_with_amazon_fips():
# The purpose of this test is to validate that we can assume a role in a FIPS region
# TODO: This is using ai.moda's IAM role, we should use LiteLLM's IAM role eventually
aws_role_name = "arn:aws:iam::335785316107:role/litellm-github-unit-tests-circleci-v1-assume-only"
aws_web_identity_token = "oidc/circleci/"
bllm = BedrockConverseLLM()
creds = bllm.get_credentials(
aws_region_name="us-west-1",
aws_web_identity_token=aws_web_identity_token,
aws_role_name=aws_role_name,
aws_session_name="assume-v1-session-fips",
aws_sts_endpoint="https://sts-fips.us-west-1.amazonaws.com",
)
def test_oidc_env_variable():
# Create a unique environment variable name
env_var_name = "OIDC_TEST_PATH_" + uuid4().hex
os.environ[env_var_name] = "secret-" + uuid4().hex
secret_val = get_secret(f"oidc/env/{env_var_name}")
print(f"secret_val: {redact_oidc_signature(secret_val)}")
assert secret_val == os.environ[env_var_name]
# now unset the environment variable
del os.environ[env_var_name]
def test_oidc_file():
# Create a temporary file
with tempfile.NamedTemporaryFile(mode="w+") as temp_file:
secret_value = "secret-" + uuid4().hex
temp_file.write(secret_value)
temp_file.flush()
temp_file_path = temp_file.name
secret_val = get_secret(f"oidc/file/{temp_file_path}")
print(f"secret_val: {redact_oidc_signature(secret_val)}")
assert secret_val == secret_value
def test_oidc_env_path():
# Create a temporary file
with tempfile.NamedTemporaryFile(mode="w+") as temp_file:
secret_value = "secret-" + uuid4().hex
temp_file.write(secret_value)
temp_file.flush()
temp_file_path = temp_file.name
# Create a unique environment variable name
env_var_name = "OIDC_TEST_PATH_" + uuid4().hex
# Set the environment variable to the temporary file path
os.environ[env_var_name] = temp_file_path
# Test getting the secret using the environment variable
secret_val = get_secret(f"oidc/env_path/{env_var_name}")
print(f"secret_val: {redact_oidc_signature(secret_val)}")
assert secret_val == secret_value
del os.environ[env_var_name]
@pytest.mark.flaky(retries=6, delay=1)
def test_google_secret_manager():
"""
Test that we can get a secret from Google Secret Manager
"""
os.environ["GOOGLE_SECRET_MANAGER_PROJECT_ID"] = "pathrise-convert-1606954137718"
from litellm.secret_managers.google_secret_manager import GoogleSecretManager
load_vertex_ai_credentials()
secret_manager = GoogleSecretManager()
secret_val = secret_manager.get_secret_from_google_secret_manager(
secret_name="OPENAI_API_KEY"
)
print("secret_val: {}".format(secret_val))
assert (
secret_val == "anything"
), "did not get expected secret value. expect 'anything', got '{}'".format(
secret_val
)
def test_google_secret_manager_read_in_memory():
"""
    Test that Google Secret Manager returns the in-memory cached value when it exists
"""
from litellm.secret_managers.google_secret_manager import GoogleSecretManager
load_vertex_ai_credentials()
os.environ["GOOGLE_SECRET_MANAGER_PROJECT_ID"] = "pathrise-convert-1606954137718"
secret_manager = GoogleSecretManager()
secret_manager.cache.cache_dict["UNIQUE_KEY"] = None
secret_manager.cache.cache_dict["UNIQUE_KEY_2"] = "lite-llm"
secret_val = secret_manager.get_secret_from_google_secret_manager(
secret_name="UNIQUE_KEY"
)
print("secret_val: {}".format(secret_val))
    assert secret_val is None
secret_val = secret_manager.get_secret_from_google_secret_manager(
secret_name="UNIQUE_KEY_2"
)
print("secret_val: {}".format(secret_val))
assert secret_val == "lite-llm"
def test_should_read_secret_from_secret_manager():
"""
Test that _should_read_secret_from_secret_manager returns correct values based on access mode
"""
from litellm.types.secret_managers.main import KeyManagementSettings
# Test when secret manager client is None
litellm.secret_manager_client = None
litellm._key_management_settings = KeyManagementSettings()
assert _should_read_secret_from_secret_manager() is False
# Test with secret manager client and read_only access
litellm.secret_manager_client = "dummy_client"
litellm._key_management_settings = KeyManagementSettings(access_mode="read_only")
assert _should_read_secret_from_secret_manager() is True
# Test with secret manager client and read_and_write access
litellm._key_management_settings = KeyManagementSettings(
access_mode="read_and_write"
)
assert _should_read_secret_from_secret_manager() is True
# Test with secret manager client and write_only access
litellm._key_management_settings = KeyManagementSettings(access_mode="write_only")
assert _should_read_secret_from_secret_manager() is False
# Reset global variables
litellm.secret_manager_client = None
litellm._key_management_settings = KeyManagementSettings()
def test_get_secret_with_access_mode():
"""
Test that get_secret respects access mode settings
"""
from litellm.types.secret_managers.main import KeyManagementSettings
# Set up test environment
test_secret_name = "TEST_SECRET_KEY"
test_secret_value = "test_secret_value"
os.environ[test_secret_name] = test_secret_value
# Test with write_only access (should read from os.environ)
litellm.secret_manager_client = "dummy_client"
litellm._key_management_settings = KeyManagementSettings(access_mode="write_only")
assert get_secret(test_secret_name) == test_secret_value
# Test with no KeyManagementSettings but secret_manager_client set
litellm.secret_manager_client = "dummy_client"
litellm._key_management_settings = KeyManagementSettings()
assert _should_read_secret_from_secret_manager() is True
# Test with read_only access
litellm._key_management_settings = KeyManagementSettings(access_mode="read_only")
assert _should_read_secret_from_secret_manager() is True
# Test with read_and_write access
litellm._key_management_settings = KeyManagementSettings(
access_mode="read_and_write"
)
assert _should_read_secret_from_secret_manager() is True
# Reset global variables
litellm.secret_manager_client = None
litellm._key_management_settings = KeyManagementSettings()
del os.environ[test_secret_name]
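# Summary of the access-mode behaviour asserted above:
#   secret_manager_client    access_mode       _should_read_secret_from_secret_manager()
#   None                     (any)             False
#   set                      default           True
#   set                      read_only         True
#   set                      read_and_write    True
#   set                      write_only        False (get_secret falls back to os.environ)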

File diff suppressed because it is too large

View File

@@ -0,0 +1,63 @@
import pytest
import sys
import os
sys.path.insert(0, os.path.abspath("../.."))
from litellm.utils import validate_chat_completion_tool_choice
def test_validate_tool_choice_none():
"""Test that None is returned as-is."""
result = validate_chat_completion_tool_choice(None)
assert result is None
def test_validate_tool_choice_string():
"""Test that string values are returned as-is."""
assert validate_chat_completion_tool_choice("auto") == "auto"
assert validate_chat_completion_tool_choice("none") == "none"
assert validate_chat_completion_tool_choice("required") == "required"
def test_validate_tool_choice_standard_dict():
"""Test standard OpenAI format with function."""
tool_choice = {"type": "function", "function": {"name": "my_function"}}
result = validate_chat_completion_tool_choice(tool_choice)
assert result == tool_choice
def test_validate_tool_choice_cursor_format():
"""Test Cursor IDE format: {"type": "auto"} -> {"type": "auto"}."""
assert validate_chat_completion_tool_choice({"type": "auto"}) == {"type": "auto"}
assert validate_chat_completion_tool_choice({"type": "none"}) == {"type": "none"}
assert validate_chat_completion_tool_choice({"type": "required"}) == {"type": "required"}
def test_validate_tool_choice_invalid_dict():
"""Test that invalid dict formats raise exceptions."""
# Missing both type and function
with pytest.raises(Exception) as exc_info:
validate_chat_completion_tool_choice({})
assert "Invalid tool choice" in str(exc_info.value)
# Invalid type value
with pytest.raises(Exception) as exc_info:
validate_chat_completion_tool_choice({"type": "invalid"})
assert "Invalid tool choice" in str(exc_info.value)
# Has type but missing function when type is "function"
with pytest.raises(Exception) as exc_info:
validate_chat_completion_tool_choice({"type": "function"})
assert "Invalid tool choice" in str(exc_info.value)
def test_validate_tool_choice_invalid_type():
"""Test that invalid types raise exceptions."""
with pytest.raises(Exception) as exc_info:
validate_chat_completion_tool_choice(123)
assert "Got=<class 'int'>" in str(exc_info.value)
with pytest.raises(Exception) as exc_info:
validate_chat_completion_tool_choice([])
assert "Got=<class 'list'>" in str(exc_info.value)

View File

@@ -0,0 +1,13 @@
{
"type": "service_account",
"project_id": "pathrise-convert-1606954137718",
"private_key_id": "",
"private_key": "",
"client_email": "ci-cd-723@pathrise-convert-1606954137718.iam.gserviceaccount.com",
"client_id": "109577393201924326488",
"auth_uri": "https://accounts.google.com/o/oauth2/auth",
"token_uri": "https://oauth2.googleapis.com/token",
"auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs",
"client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/ci-cd-723%40pathrise-convert-1606954137718.iam.gserviceaccount.com",
"universe_domain": "googleapis.com"
}