Added LiteLLM to the stack
54
Development/litellm/tests/litellm_utils_tests/conftest.py
Normal file
@@ -0,0 +1,54 @@
# conftest.py

import importlib
import os
import sys

import pytest

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import litellm


@pytest.fixture(scope="function", autouse=True)
def setup_and_teardown():
    """
    This fixture reloads litellm before every test function, to speed up
    testing by preventing callbacks from being chained across tests.
    """
    curr_dir = os.getcwd()  # Get the current working directory
    sys.path.insert(
        0, os.path.abspath("../..")
    )  # Adds the project directory to the system path

    import litellm
    from litellm import Router

    importlib.reload(litellm)
    import asyncio

    loop = asyncio.get_event_loop_policy().new_event_loop()
    asyncio.set_event_loop(loop)
    print(litellm)
    # from litellm import Router, completion, aembedding, acompletion, embedding
    yield

    # Teardown code (executes after the yield point)
    loop.close()  # Close the loop created earlier
    asyncio.set_event_loop(None)  # Remove the reference to the loop


def pytest_collection_modifyitems(config, items):
    # Separate tests in 'test_amazing_proxy_custom_logger.py' from other tests
    custom_logger_tests = [
        item for item in items if "custom_logger" in item.parent.name
    ]
    other_tests = [item for item in items if "custom_logger" not in item.parent.name]

    # Sort tests based on their names
    custom_logger_tests.sort(key=lambda x: x.name)
    other_tests.sort(key=lambda x: x.name)

    # Reorder the items list so custom logger tests run first
    items[:] = custom_logger_tests + other_tests
24
Development/litellm/tests/litellm_utils_tests/log.txt
Normal file
@@ -0,0 +1,24 @@
============================= test session starts ==============================
platform darwin -- Python 3.13.1, pytest-8.3.5, pluggy-1.5.0 -- /Users/krrishdholakia/Documents/litellm/myenv/bin/python3.13
cachedir: .pytest_cache
rootdir: /Users/krrishdholakia/Documents/litellm
configfile: pyproject.toml
plugins: respx-0.22.0, postgresql-7.0.1, anyio-4.4.0, asyncio-0.26.0, mock-3.14.0, ddtrace-2.19.0rc1
asyncio: mode=Mode.STRICT, asyncio_default_fixture_loop_scope=None, asyncio_default_test_loop_scope=function
collecting ... collected 3 items

test_supports_tool_choice.py::test_check_provider_match PASSED [ 33%]
test_supports_tool_choice.py::test_supports_tool_choice PASSED [ 66%]
test_supports_tool_choice.py::test_supports_tool_choice_simple_tests PASSED [100%]

=============================== warnings summary ===============================
../../myenv/lib/python3.13/site-packages/pydantic/_internal/_config.py:295
  /Users/krrishdholakia/Documents/litellm/myenv/lib/python3.13/site-packages/pydantic/_internal/_config.py:295: PydanticDeprecatedSince20: Support for class-based `config` is deprecated, use ConfigDict instead. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.10/migration/
    warnings.warn(DEPRECATION_MESSAGE, DeprecationWarning)

../../litellm/caching/llm_caching_handler.py:17
  /Users/krrishdholakia/Documents/litellm/litellm/caching/llm_caching_handler.py:17: DeprecationWarning: There is no current event loop
    event_loop = asyncio.get_event_loop()

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
======================== 3 passed, 2 warnings in 0.92s =========================
@@ -0,0 +1,124 @@
import asyncio
import copy
import sys
import time
from datetime import datetime
from unittest import mock

from dotenv import load_dotenv

from litellm.types.utils import StandardCallbackDynamicParams

load_dotenv()
import os

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import pytest

import litellm
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler


@pytest.mark.asyncio
async def test_client_session_helper():
    """Test that the client session helper handles event loop changes correctly"""
    try:
        # Create a transport with the new helper
        transport = AsyncHTTPHandler._create_aiohttp_transport()
        if transport is not None:
            print('✅ Successfully created aiohttp transport with helper')

            # Test the helper function directly if it's a LiteLLMAiohttpTransport
            if hasattr(transport, '_get_valid_client_session'):
                session1 = transport._get_valid_client_session()  # type: ignore
                print(f'✅ First session created: {type(session1).__name__}')

                # Call it again to test reuse
                session2 = transport._get_valid_client_session()  # type: ignore
                print(f'✅ Second session call: {type(session2).__name__}')

                # In the same event loop, should be the same session
                print(f'✅ Same session reused: {session1 is session2}')

            return True
        else:
            print('ℹ️ No aiohttp transport available (probably missing httpx-aiohttp)')
            return True
    except Exception as e:
        print(f'❌ Error: {e}')
        import traceback
        traceback.print_exc()
        return False


@pytest.mark.asyncio
async def test_event_loop_robustness():
    """Test behavior when event loops change (simulating CI/CD scenario)"""
    try:
        # Test session creation in multiple scenarios
        transport = AsyncHTTPHandler._create_aiohttp_transport()

        if transport and hasattr(transport, '_get_valid_client_session'):
            # Test 1: Normal usage
            session = transport._get_valid_client_session()  # type: ignore
            print(f'✅ Normal session creation works: {session is not None}')

            # Test 2: Force recreation by setting client to a callable
            from aiohttp import ClientSession
            transport.client = lambda: ClientSession()  # type: ignore
            session2 = transport._get_valid_client_session()  # type: ignore
            print(f'✅ Session recreation after callable works: {session2 is not None}')

            return True
        else:
            print('ℹ️ Transport not available or no helper method')
            return True

    except Exception as e:
        print(f'❌ Error in event loop robustness test: {e}')
        import traceback
        traceback.print_exc()
        return False


@pytest.mark.asyncio
async def test_httpx_request_simulation():
    """Test that the transport can handle a simulated HTTP request"""
    try:
        transport = AsyncHTTPHandler._create_aiohttp_transport()

        if transport is not None:
            print('✅ Transport created for request simulation')

            # Create a simple httpx request to test with
            import httpx
            request = httpx.Request('GET', 'https://httpbin.org/headers')

            # Just test that we can get a valid session for this request context
            if hasattr(transport, '_get_valid_client_session'):
                session = transport._get_valid_client_session()  # type: ignore
                print(f'✅ Got valid session for request: {session is not None}')

                # Test that session has required aiohttp methods
                has_request_method = hasattr(session, 'request')
                print(f'✅ Session has request method: {has_request_method}')

                return has_request_method

            return True
        else:
            print('ℹ️ No transport available for request simulation')
            return True

    except Exception as e:
        print(f'❌ Error in request simulation: {e}')
        return False


if __name__ == "__main__":
    print("Testing client session helper and event loop handling fix...")

    result1 = asyncio.run(test_client_session_helper())
    result2 = asyncio.run(test_event_loop_robustness())
    result3 = asyncio.run(test_httpx_request_simulation())

    if result1 and result2 and result3:
        print("🎉 All tests passed! The helper function approach should fix the CI/CD event loop issues.")
    else:
        print("💥 Some tests failed")
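
For orientation, here is a minimal sketch of the kind of loop-aware session helper the tests above exercise. The class and method names are illustrative assumptions, not LiteLLM's actual implementation (the real transport sits behind AsyncHTTPHandler._create_aiohttp_transport()):

# Illustrative sketch (not LiteLLM's implementation): cache one
# aiohttp.ClientSession per running event loop, and rebuild it when
# pytest or a CI runner swaps in a fresh loop.
import asyncio
from typing import Optional
from aiohttp import ClientSession


class LoopAwareSessionHolder:
    def __init__(self) -> None:
        self._session: Optional[ClientSession] = None
        self._loop: Optional[asyncio.AbstractEventLoop] = None

    def get_valid_client_session(self) -> ClientSession:
        loop = asyncio.get_running_loop()
        # Recreate the session if none exists, it was closed, or the
        # running loop changed since the session was created.
        if self._session is None or self._session.closed or self._loop is not loop:
            self._session = ClientSession()
            self._loop = loop
        return self._session

Caching per loop rather than per process is the point: a session bound to a closed loop raises "Event loop is closed" on the next request, which is the CI/CD failure mode these tests simulate.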
@@ -0,0 +1,195 @@
# What is this?

import asyncio
import os
import sys
import traceback

from dotenv import load_dotenv

import litellm.types
import litellm.types.utils


load_dotenv()
import io

import sys
import os

# Ensure the project root is in the Python path
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../../..")))

print("Python Path:", sys.path)
print("Current Working Directory:", os.getcwd())


from typing import Optional
from unittest.mock import MagicMock, patch

import pytest
import uuid
import json
from litellm.secret_managers.aws_secret_manager_v2 import AWSSecretsManagerV2


def check_aws_credentials():
    """Helper function to check if AWS credentials are set"""
    required_vars = ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_REGION_NAME"]
    missing_vars = [var for var in required_vars if not os.getenv(var)]
    if missing_vars:
        pytest.skip(f"Missing required AWS credentials: {', '.join(missing_vars)}")


@pytest.mark.asyncio
async def test_write_and_read_simple_secret():
    """Test writing and reading a simple string secret"""
    check_aws_credentials()

    secret_manager = AWSSecretsManagerV2()
    test_secret_name = f"litellm_test_{uuid.uuid4().hex[:8]}"
    test_secret_value = "test_value_123"

    try:
        # Write secret
        write_response = await secret_manager.async_write_secret(
            secret_name=test_secret_name,
            secret_value=test_secret_value,
            description="LiteLLM Test Secret",
        )

        print("Write Response:", write_response)

        assert write_response is not None
        assert "ARN" in write_response
        assert "Name" in write_response
        assert write_response["Name"] == test_secret_name

        # Read secret back
        read_value = await secret_manager.async_read_secret(
            secret_name=test_secret_name
        )

        print("Read Value:", read_value)

        assert read_value == test_secret_value
    finally:
        # Cleanup: Delete the secret
        delete_response = await secret_manager.async_delete_secret(
            secret_name=test_secret_name
        )
        print("Delete Response:", delete_response)
        assert delete_response is not None


@pytest.mark.asyncio
async def test_write_and_read_json_secret():
    """Test writing and reading a JSON structured secret"""
    check_aws_credentials()

    secret_manager = AWSSecretsManagerV2()
    test_secret_name = f"litellm_test_{uuid.uuid4().hex[:8]}_json"
    test_secret_value = {
        "api_key": "test_key",
        "model": "gpt-4",
        "temperature": 0.7,
        "metadata": {"team": "ml", "project": "litellm"},
    }

    try:
        # Write JSON secret
        write_response = await secret_manager.async_write_secret(
            secret_name=test_secret_name,
            secret_value=json.dumps(test_secret_value),
            description="LiteLLM JSON Test Secret",
        )

        print("Write Response:", write_response)

        # Read and parse JSON secret
        read_value = await secret_manager.async_read_secret(
            secret_name=test_secret_name
        )
        parsed_value = json.loads(read_value)

        print("Read Value:", read_value)

        assert parsed_value == test_secret_value
        assert parsed_value["api_key"] == "test_key"
        assert parsed_value["metadata"]["team"] == "ml"
    finally:
        # Cleanup: Delete the secret
        delete_response = await secret_manager.async_delete_secret(
            secret_name=test_secret_name
        )
        print("Delete Response:", delete_response)
        assert delete_response is not None


@pytest.mark.asyncio
async def test_read_nonexistent_secret():
    """Test reading a secret that doesn't exist"""
    check_aws_credentials()

    secret_manager = AWSSecretsManagerV2()
    nonexistent_secret = f"litellm_nonexistent_{uuid.uuid4().hex}"

    response = await secret_manager.async_read_secret(secret_name=nonexistent_secret)

    assert response is None


@pytest.mark.asyncio
async def test_primary_secret_functionality():
    """Test storing and retrieving secrets from a primary secret"""
    check_aws_credentials()

    secret_manager = AWSSecretsManagerV2()
    primary_secret_name = f"litellm_test_primary_{uuid.uuid4().hex[:8]}"

    # Create a primary secret with multiple key-value pairs
    primary_secret_value = {
        "api_key_1": "secret_value_1",
        "api_key_2": "secret_value_2",
        "database_url": "postgresql://user:password@localhost:5432/db",
        "nested_secret": json.dumps({"key": "value", "number": 42}),
    }

    try:
        # Write the primary secret
        write_response = await secret_manager.async_write_secret(
            secret_name=primary_secret_name,
            secret_value=json.dumps(primary_secret_value),
            description="LiteLLM Test Primary Secret",
        )

        print("Primary Secret Write Response:", write_response)
        assert write_response is not None
        assert "ARN" in write_response
        assert "Name" in write_response
        assert write_response["Name"] == primary_secret_name

        # Test reading individual secrets from the primary secret
        for key, expected_value in primary_secret_value.items():
            # Read using the primary_secret_name parameter
            value = await secret_manager.async_read_secret(
                secret_name=key, primary_secret_name=primary_secret_name
            )

            print(f"Read {key} from primary secret:", value)
            assert value == expected_value

        # Test reading a non-existent key from the primary secret
        non_existent_key = "non_existent_key"
        value = await secret_manager.async_read_secret(
            secret_name=non_existent_key, primary_secret_name=primary_secret_name
        )
        assert value is None, f"Expected None for non-existent key, got {value}"

    finally:
        # Cleanup: Delete the primary secret
        delete_response = await secret_manager.async_delete_secret(
            secret_name=primary_secret_name
        )
        print("Delete Response:", delete_response)
        assert delete_response is not None
@@ -0,0 +1,30 @@
import json
import os
import sys
from datetime import datetime
from unittest.mock import AsyncMock, Mock, patch

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import pytest

import litellm
from litellm.proxy._types import KeyManagementSystem
from litellm.secret_managers.main import get_secret


class MockSecretClient:
    def get_secret(self, secret_name):
        return Mock(value="mocked_secret_value")


@pytest.mark.asyncio
async def test_azure_kms():
    """
    Basic check that get_secret returns the value from Azure Key Vault
    when the key management system is set to Azure Key Vault.
    """
    with patch("litellm.secret_manager_client", new=MockSecretClient()):
        litellm._key_management_system = KeyManagementSystem.AZURE_KEY_VAULT
        secret = get_secret(secret_name="ishaan-test-key")
        assert secret == "mocked_secret_value"
213
Development/litellm/tests/litellm_utils_tests/test_hashicorp.py
Normal file
@@ -0,0 +1,213 @@
import os
import sys
import pytest
from dotenv import load_dotenv

load_dotenv()
import os
import httpx

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
from unittest.mock import patch, MagicMock
import logging
from litellm._logging import verbose_logger
import uuid

verbose_logger.setLevel(logging.DEBUG)

from litellm.secret_managers.hashicorp_secret_manager import HashicorpSecretManager

hashicorp_secret_manager = HashicorpSecretManager()


mock_vault_response = {
    "request_id": "80fafb6a-e96a-4c5b-29fa-ff505ac72201",
    "lease_id": "",
    "renewable": False,
    "lease_duration": 0,
    "data": {
        "data": {"key": "value-mock"},
        "metadata": {
            "created_time": "2025-01-01T22:13:50.93942388Z",
            "custom_metadata": None,
            "deletion_time": "",
            "destroyed": False,
            "version": 1,
        },
    },
    "wrap_info": None,
    "warnings": None,
    "auth": None,
    "mount_type": "kv",
}

# A separate mock response for write operations
mock_write_response = {
    "request_id": "80fafb6a-e96a-4c5b-29fa-ff505ac72201",
    "lease_id": "",
    "renewable": False,
    "lease_duration": 0,
    "data": {
        "created_time": "2025-01-04T16:58:42.684673531Z",
        "custom_metadata": None,
        "deletion_time": "",
        "destroyed": False,
        "version": 1,
    },
    "wrap_info": None,
    "warnings": None,
    "auth": None,
    "mount_type": "kv",
}


def test_hashicorp_secret_manager_get_secret():
    with patch("litellm.llms.custom_httpx.http_handler.HTTPHandler.get") as mock_get:
        # Configure the mock response using MagicMock
        mock_response = MagicMock()
        mock_response.json.return_value = mock_vault_response
        mock_response.raise_for_status.return_value = None
        mock_get.return_value = mock_response

        # Test the secret manager
        secret = hashicorp_secret_manager.sync_read_secret("sample-secret-mock")
        assert secret == "value-mock"

        # Verify the request was made with correct parameters
        mock_get.assert_called_once()
        called_url = mock_get.call_args[0][0]
        assert "sample-secret-mock" in called_url

        assert (
            called_url
            == "https://test-cluster-public-vault-0f98180c.e98296b2.z1.hashicorp.cloud:8200/v1/admin/secret/data/sample-secret-mock"
        )
        assert "X-Vault-Token" in mock_get.call_args.kwargs["headers"]


@pytest.mark.asyncio
async def test_hashicorp_secret_manager_write_secret():
    with patch(
        "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post"
    ) as mock_post:
        # Configure the mock response
        mock_response = MagicMock()
        mock_response.json.return_value = (
            mock_write_response  # Use the write-specific response
        )
        mock_response.raise_for_status.return_value = None
        mock_post.return_value = mock_response

        # Test the secret manager
        secret_name = f"sample-secret-test-{uuid.uuid4()}"
        secret_value = f"value-mock-{uuid.uuid4()}"
        response = await hashicorp_secret_manager.async_write_secret(
            secret_name=secret_name,
            secret_value=secret_value,
        )

        # Verify the response and that the request was made correctly
        assert (
            response == mock_write_response
        )  # Compare against write-specific response
        mock_post.assert_called_once()
        print("CALL ARGS=", mock_post.call_args)
        print("call args[1]=", mock_post.call_args[1])

        # Verify URL
        called_url = mock_post.call_args[1]["url"]
        assert secret_name in called_url
        assert (
            called_url
            == f"{hashicorp_secret_manager.vault_addr}/v1/admin/secret/data/{secret_name}"
        )

        # Verify request body
        json_data = mock_post.call_args[1]["json"]
        assert "data" in json_data
        assert "key" in json_data["data"]
        assert json_data["data"]["key"] == secret_value


@pytest.mark.asyncio
async def test_hashicorp_secret_manager_delete_secret():
    with patch(
        "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.delete"
    ) as mock_delete:
        # Configure the mock response
        mock_response = MagicMock()
        mock_response.raise_for_status.return_value = None
        mock_delete.return_value = mock_response

        # Test the secret manager
        secret_name = f"sample-secret-test-{uuid.uuid4()}"
        response = await hashicorp_secret_manager.async_delete_secret(
            secret_name=secret_name
        )

        # Verify the response
        assert response == {
            "status": "success",
            "message": f"Secret {secret_name} deleted successfully",
        }

        # Verify the request was made correctly
        mock_delete.assert_called_once()

        # Verify URL
        called_url = mock_delete.call_args[1]["url"]
        assert secret_name in called_url
        assert (
            called_url
            == f"{hashicorp_secret_manager.vault_addr}/v1/admin/secret/data/{secret_name}"
        )


def test_hashicorp_secret_manager_tls_cert_auth(monkeypatch):
    monkeypatch.setenv("HCP_VAULT_TOKEN", "test-client-token-12345")
    print("HCP_VAULT_TOKEN=", os.getenv("HCP_VAULT_TOKEN"))
    # Mock both httpx.post and httpx.Client
    with patch("httpx.Client") as mock_client:
        # Configure the mock client and response
        mock_response = MagicMock()
        mock_response.json.return_value = {
            "auth": {
                "client_token": "test-client-token-12345",
                "lease_duration": 3600,
                "renewable": True,
            }
        }
        mock_response.raise_for_status.return_value = None

        # Configure the mock client's post method
        mock_client_instance = MagicMock()
        mock_client_instance.post.return_value = mock_response
        mock_client.return_value = mock_client_instance

        # Create a new instance with TLS cert config
        test_manager = HashicorpSecretManager()
        test_manager.tls_cert_path = "cert.pem"
        test_manager.tls_key_path = "key.pem"
        test_manager.vault_cert_role = "test-role"
        test_manager.vault_namespace = "test-namespace"

        # Test the TLS auth method
        token = test_manager._auth_via_tls_cert()

        # Verify the token
        assert token == "test-client-token-12345"

        # Verify Client was created with correct cert tuple
        mock_client.assert_called_once_with(cert=("cert.pem", "key.pem"))

        # Verify post was called with correct parameters
        mock_client_instance.post.assert_called_once_with(
            f"{test_manager.vault_addr}/v1/auth/cert/login",
            headers={"X-Vault-Namespace": "test-namespace"},
            json={"name": "test-role"},
        )

        # Verify the token was cached
        assert test_manager.cache.get_cache("hcp_vault_token") == "test-client-token-12345"
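
A note on the response shape mocked above: Vault's KV v2 API nests the stored key/value pairs under data.data, so with the single-key convention these tests assume, a read reduces to two dict lookups. A self-contained sketch:

# Trimmed KV v2 response shape, matching the mock above.
vault_kv2_response = {"data": {"data": {"key": "value-mock"}}}
secret = vault_kv2_response["data"]["data"]["key"]
assert secret == "value-mock"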
@@ -0,0 +1,116 @@
import json
import os
import sys
import time
from datetime import datetime
from unittest.mock import AsyncMock, patch, MagicMock
import pytest

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import litellm


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model",
    [
        "bedrock/mistral.mistral-7b-instruct-v0:2",
        "openai/gpt-4o",
        "openai/self_hosted",
        "bedrock/anthropic.claude-3-5-haiku-20241022-v1:0",
    ],
)
async def test_litellm_overhead(model):

    litellm._turn_on_debug()
    start_time = datetime.now()
    if model == "openai/self_hosted":
        response = await litellm.acompletion(
            model=model,
            messages=[{"role": "user", "content": "Hello, world!"}],
            api_base="https://exampleopenaiendpoint-production.up.railway.app/",
        )
    else:
        response = await litellm.acompletion(
            model=model,
            messages=[{"role": "user", "content": "Hello, world!"}],
        )
    end_time = datetime.now()
    total_time_ms = (end_time - start_time).total_seconds() * 1000
    print(response)
    print(response._hidden_params)
    litellm_overhead_ms = response._hidden_params["litellm_overhead_time_ms"]
    # calculate percent of overhead caused by litellm
    overhead_percent = litellm_overhead_ms * 100 / total_time_ms
    print("##########################\n")
    print("total_time_ms", total_time_ms)
    print("response litellm_overhead_ms", litellm_overhead_ms)
    print("litellm overhead_percent {}%".format(overhead_percent))
    print("##########################\n")
    assert litellm_overhead_ms > 0
    assert litellm_overhead_ms < 1000

    # latency overhead should be less than total request time
    assert litellm_overhead_ms < (end_time - start_time).total_seconds() * 1000

    # latency overhead should be under 40% of total request time
    assert overhead_percent < 40


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model",
    [
        "bedrock/mistral.mistral-7b-instruct-v0:2",
        "openai/gpt-4o",
        "bedrock/anthropic.claude-3-5-haiku-20241022-v1:0",
        "openai/self_hosted",
    ],
)
async def test_litellm_overhead_stream(model):

    litellm._turn_on_debug()
    start_time = datetime.now()
    if model == "openai/self_hosted":
        response = await litellm.acompletion(
            model=model,
            messages=[{"role": "user", "content": "Hello, world!"}],
            api_base="https://exampleopenaiendpoint-production.up.railway.app/",
            stream=True,
        )
    else:
        response = await litellm.acompletion(
            model=model,
            messages=[{"role": "user", "content": "Hello, world!"}],
            stream=True,
        )

    async for chunk in response:
        print()

    end_time = datetime.now()
    total_time_ms = (end_time - start_time).total_seconds() * 1000
    print(response)
    print(response._hidden_params)
    litellm_overhead_ms = response._hidden_params["litellm_overhead_time_ms"]
    # calculate percent of overhead caused by litellm
    overhead_percent = litellm_overhead_ms * 100 / total_time_ms
    print("##########################\n")
    print("total_time_ms", total_time_ms)
    print("response litellm_overhead_ms", litellm_overhead_ms)
    print("litellm overhead_percent {}%".format(overhead_percent))
    print("##########################\n")
    assert litellm_overhead_ms > 0
    assert litellm_overhead_ms < 1000

    # latency overhead should be less than total request time
    assert litellm_overhead_ms < (end_time - start_time).total_seconds() * 1000

    # latency overhead should be under 40% of total request time
    assert overhead_percent < 40
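
To make the overhead arithmetic concrete, with hypothetical numbers (not taken from a real run):

# Worked example: a 900 ms request where litellm accounts for 45 ms.
litellm_overhead_ms = 45
total_time_ms = 900
overhead_percent = litellm_overhead_ms * 100 / total_time_ms  # 5.0
assert overhead_percent < 40  # comfortably under the asserted ceiling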
@@ -0,0 +1,279 @@
import json
import os
import sys
import time
from datetime import datetime
from unittest.mock import AsyncMock, patch, MagicMock
import pytest

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import litellm
from litellm.integrations.custom_logger import CustomLogger
from litellm.litellm_core_utils.logging_callback_manager import LoggingCallbackManager
from litellm.integrations.langfuse.langfuse_prompt_management import (
    LangfusePromptManagement,
)
from litellm.integrations.opentelemetry import OpenTelemetry


# Test fixtures
@pytest.fixture
def callback_manager():
    manager = LoggingCallbackManager()
    # Reset callbacks before each test
    manager._reset_all_callbacks()
    return manager


@pytest.fixture
def mock_custom_logger():
    class TestLogger(CustomLogger):
        def log_success_event(self, kwargs, response_obj, start_time, end_time):
            pass

    return TestLogger()


# Test cases
def test_add_string_callback():
    """
    Test adding a string callback to litellm.callbacks - only 1 instance of the string callback should be added
    """
    manager = LoggingCallbackManager()
    test_callback = "test_callback"

    # Add string callback
    manager.add_litellm_callback(test_callback)
    assert test_callback in litellm.callbacks

    # Test duplicate prevention
    manager.add_litellm_callback(test_callback)
    assert litellm.callbacks.count(test_callback) == 1


def test_duplicate_langfuse_logger_test():
    manager = LoggingCallbackManager()
    for _ in range(10):
        langfuse_logger = LangfusePromptManagement()
        manager.add_litellm_success_callback(langfuse_logger)
    print("litellm.success_callback: ", litellm.success_callback)
    assert len(litellm.success_callback) == 1


def test_duplicate_multiple_loggers_test():
    manager = LoggingCallbackManager()
    for _ in range(10):
        langfuse_logger = LangfusePromptManagement()
        otel_logger = OpenTelemetry()
        manager.add_litellm_success_callback(langfuse_logger)
        manager.add_litellm_success_callback(otel_logger)
    print("litellm.success_callback: ", litellm.success_callback)
    assert len(litellm.success_callback) == 2

    # Check exactly one instance of each logger type
    langfuse_count = sum(
        1
        for callback in litellm.success_callback
        if isinstance(callback, LangfusePromptManagement)
    )
    otel_count = sum(
        1
        for callback in litellm.success_callback
        if isinstance(callback, OpenTelemetry)
    )

    assert (
        langfuse_count == 1
    ), "Should have exactly one LangfusePromptManagement instance"
    assert otel_count == 1, "Should have exactly one OpenTelemetry instance"


def test_add_function_callback():
    manager = LoggingCallbackManager()

    def test_func(kwargs):
        pass

    # Add function callback
    manager.add_litellm_callback(test_func)
    assert test_func in litellm.callbacks

    # Test duplicate prevention
    manager.add_litellm_callback(test_func)
    assert litellm.callbacks.count(test_func) == 1


def test_add_custom_logger(mock_custom_logger):
    manager = LoggingCallbackManager()

    # Add custom logger
    manager.add_litellm_callback(mock_custom_logger)
    assert mock_custom_logger in litellm.callbacks


def test_add_multiple_callback_types(mock_custom_logger):
    manager = LoggingCallbackManager()

    def test_func(kwargs):
        pass

    string_callback = "test_callback"

    # Add different types of callbacks
    manager.add_litellm_callback(string_callback)
    manager.add_litellm_callback(test_func)
    manager.add_litellm_callback(mock_custom_logger)

    assert string_callback in litellm.callbacks
    assert test_func in litellm.callbacks
    assert mock_custom_logger in litellm.callbacks
    assert len(litellm.callbacks) == 3


def test_success_failure_callbacks():
    manager = LoggingCallbackManager()

    success_callback = "success_callback"
    failure_callback = "failure_callback"

    # Add callbacks
    manager.add_litellm_success_callback(success_callback)
    manager.add_litellm_failure_callback(failure_callback)

    assert success_callback in litellm.success_callback
    assert failure_callback in litellm.failure_callback


def test_async_callbacks():
    manager = LoggingCallbackManager()

    async_success = "async_success"
    async_failure = "async_failure"

    # Add async callbacks
    manager.add_litellm_async_success_callback(async_success)
    manager.add_litellm_async_failure_callback(async_failure)

    assert async_success in litellm._async_success_callback
    assert async_failure in litellm._async_failure_callback


def test_remove_callback_from_list_by_object():
    manager = LoggingCallbackManager()
    # Reset all callbacks
    manager._reset_all_callbacks()

    class TestObject:
        def __init__(self):
            manager.add_litellm_callback(self.callback)
            manager.add_litellm_success_callback(self.callback)
            manager.add_litellm_failure_callback(self.callback)
            manager.add_litellm_async_success_callback(self.callback)
            manager.add_litellm_async_failure_callback(self.callback)

        def callback(self):
            pass

    obj = TestObject()

    manager.remove_callback_from_list_by_object(litellm.callbacks, obj)
    manager.remove_callback_from_list_by_object(litellm.success_callback, obj)
    manager.remove_callback_from_list_by_object(litellm.failure_callback, obj)
    manager.remove_callback_from_list_by_object(litellm._async_success_callback, obj)
    manager.remove_callback_from_list_by_object(litellm._async_failure_callback, obj)

    # Verify all callback lists are empty
    assert len(litellm.callbacks) == 0
    assert len(litellm.success_callback) == 0
    assert len(litellm.failure_callback) == 0
    assert len(litellm._async_success_callback) == 0
    assert len(litellm._async_failure_callback) == 0


def test_reset_callbacks(callback_manager):
    # Add various callbacks
    callback_manager.add_litellm_callback("test")
    callback_manager.add_litellm_success_callback("success")
    callback_manager.add_litellm_failure_callback("failure")
    callback_manager.add_litellm_async_success_callback("async_success")
    callback_manager.add_litellm_async_failure_callback("async_failure")

    # Reset all callbacks
    callback_manager._reset_all_callbacks()

    # Verify all callback lists are empty
    assert len(litellm.callbacks) == 0
    assert len(litellm.success_callback) == 0
    assert len(litellm.failure_callback) == 0
    assert len(litellm._async_success_callback) == 0
    assert len(litellm._async_failure_callback) == 0


@pytest.mark.asyncio
async def test_slack_alerting_callback_registration(callback_manager):
    """
    Test that litellm callbacks are correctly registered for slack alerting
    when outage_alerts or region_outage_alerts are enabled
    """
    from litellm.caching.caching import DualCache
    from litellm.proxy.utils import ProxyLogging
    from litellm.integrations.SlackAlerting.slack_alerting import SlackAlerting
    from unittest.mock import AsyncMock, patch

    # Mock the async HTTP handler
    with patch('litellm.integrations.SlackAlerting.slack_alerting.get_async_httpx_client') as mock_http:
        mock_http.return_value = AsyncMock()

        # Create a fresh ProxyLogging instance
        proxy_logging = ProxyLogging(user_api_key_cache=DualCache())

        # Test 1: No callbacks should be added when alerting is None
        proxy_logging.update_values(
            alerting=None,
            alert_types=["outage_alerts", "region_outage_alerts"]
        )
        assert len(litellm.callbacks) == 0

        # Test 2: Callbacks should be added when slack alerting is enabled with outage alerts
        proxy_logging.update_values(
            alerting=["slack"],
            alert_types=["outage_alerts"]
        )
        assert len(litellm.callbacks) == 1
        assert isinstance(litellm.callbacks[0], SlackAlerting)

        # Test 3: Callbacks should be added when slack alerting is enabled with region outage alerts
        callback_manager._reset_all_callbacks()  # Reset callbacks
        proxy_logging.update_values(
            alerting=["slack"],
            alert_types=["region_outage_alerts"]
        )
        assert len(litellm.callbacks) == 1
        assert isinstance(litellm.callbacks[0], SlackAlerting)

        # Test 4: No callbacks should be added for other alert types
        callback_manager._reset_all_callbacks()  # Reset callbacks
        proxy_logging.update_values(
            alerting=["slack"],
            alert_types=["budget_alerts"]  # Some other alert type
        )
        assert len(litellm.callbacks) == 0

        # Test 5: Both success and regular callbacks should be added
        callback_manager._reset_all_callbacks()  # Reset callbacks
        proxy_logging.update_values(
            alerting=["slack"],
            alert_types=["outage_alerts"]
        )
        assert len(litellm.callbacks) == 1  # Regular callback for outage alerts
        assert len(litellm.success_callback) == 1  # Success callback for response_taking_too_long
        assert isinstance(litellm.callbacks[0], SlackAlerting)
        # Get the method reference for comparison
        response_taking_too_long_callback = proxy_logging.slack_alerting_instance.response_taking_too_long_callback
        assert litellm.success_callback[0] == response_taking_too_long_callback

        # Cleanup
        callback_manager._reset_all_callbacks()
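
The duplicate-prevention behaviour pinned down above can be summarised in a sketch inferred from the assertions (illustrative only; the real logic lives in LoggingCallbackManager): strings and plain functions are deduped by equality, while CustomLogger instances are deduped by class.

# Sketch of the dedup rules the assertions above imply.
from litellm.integrations.custom_logger import CustomLogger


def add_callback_sketch(callback, callback_list):
    if isinstance(callback, CustomLogger):
        # At most one instance per logger class (e.g. one OpenTelemetry).
        if any(type(existing) is type(callback) for existing in callback_list):
            return
    elif callback in callback_list:
        # Strings and plain functions: skip exact duplicates.
        return
    callback_list.append(callback)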
File diff suppressed because it is too large
@@ -0,0 +1,360 @@
import os
import sys
import time
import traceback
import uuid

from dotenv import load_dotenv
import json

load_dotenv()
import os
import tempfile
from uuid import uuid4

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import pytest
import litellm
from litellm.llms.azure.azure import get_azure_ad_token_from_oidc
from litellm.llms.bedrock.chat import BedrockConverseLLM, BedrockLLM
from litellm.secret_managers.aws_secret_manager_v2 import AWSSecretsManagerV2
from litellm.secret_managers.main import (
    get_secret,
    _should_read_secret_from_secret_manager,
)


def load_vertex_ai_credentials():
    # Define the path to the vertex_key.json file
    print("loading vertex ai credentials")
    filepath = os.path.dirname(os.path.abspath(__file__))
    vertex_key_path = filepath + "/vertex_key.json"

    # Read the existing content of the file or create an empty dictionary
    try:
        with open(vertex_key_path, "r") as file:
            # Read the file content
            print("Read vertexai file path")
            content = file.read()

            # If the file is empty or not valid JSON, create an empty dictionary
            if not content or not content.strip():
                service_account_key_data = {}
            else:
                # Attempt to load the existing JSON content
                file.seek(0)
                service_account_key_data = json.load(file)
    except FileNotFoundError:
        # If the file doesn't exist, create an empty dictionary
        service_account_key_data = {}

    # Update the service_account_key_data with environment variables
    private_key_id = os.environ.get("VERTEX_AI_PRIVATE_KEY_ID", "")
    private_key = os.environ.get("VERTEX_AI_PRIVATE_KEY", "")
    private_key = private_key.replace("\\n", "\n")
    service_account_key_data["private_key_id"] = private_key_id
    service_account_key_data["private_key"] = private_key

    # Create a temporary file
    with tempfile.NamedTemporaryFile(mode="w+", delete=False) as temp_file:
        # Write the updated content to the temporary file
        json.dump(service_account_key_data, temp_file, indent=2)

        # Export the temporary file as GOOGLE_APPLICATION_CREDENTIALS
        os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = os.path.abspath(temp_file.name)


def test_aws_secret_manager():
    import json

    AWSSecretsManagerV2.load_aws_secret_manager(use_aws_secret_manager=True)

    secret_val = get_secret("litellm_master_key")

    print(f"secret_val: {secret_val}")

    # cast json to dict
    secret_val = json.loads(secret_val)

    assert secret_val["litellm_master_key"] == "sk-1234"


def redact_oidc_signature(secret_val):
    # remove the last `.`-separated part and replace it with "SIGNATURE_REMOVED"
    return secret_val.split(".")[:-1] + ["SIGNATURE_REMOVED"]


@pytest.mark.skipif(
    os.environ.get("K_SERVICE") is None,
    reason="Cannot run without being in GCP Cloud Run",
)
def test_oidc_google():
    secret_val = get_secret(
        "oidc/google/https://bedrock-runtime.us-east-1.amazonaws.com/model/amazon.titan-text-express-v1/invoke"
    )

    print(f"secret_val: {redact_oidc_signature(secret_val)}")


@pytest.mark.skipif(
    os.environ.get("ACTIONS_ID_TOKEN_REQUEST_TOKEN") is None,
    reason="Cannot run without being in GitHub Actions",
)
def test_oidc_github():
    secret_val = get_secret(
        "oidc/github/https://bedrock-runtime.us-east-1.amazonaws.com/model/amazon.titan-text-express-v1/invoke"
    )

    print(f"secret_val: {redact_oidc_signature(secret_val)}")


@pytest.mark.skipif(
    os.environ.get("CIRCLE_OIDC_TOKEN") is None,
    reason="Cannot run without being in CircleCI Runner",
)
def test_oidc_circleci():
    secret_val = get_secret("oidc/circleci/")

    print(f"secret_val: {redact_oidc_signature(secret_val)}")


@pytest.mark.skipif(
    os.environ.get("CIRCLE_OIDC_TOKEN_V2") is None,
    reason="Cannot run without being in CircleCI Runner",
)
def test_oidc_circleci_v2():
    secret_val = get_secret(
        "oidc/circleci_v2/https://bedrock-runtime.us-east-1.amazonaws.com/model/amazon.titan-text-express-v1/invoke"
    )

    print(f"secret_val: {redact_oidc_signature(secret_val)}")


@pytest.mark.skipif(
    os.environ.get("CIRCLE_OIDC_TOKEN") is None,
    reason="Cannot run without being in CircleCI Runner",
)
def test_oidc_circleci_with_azure():
    # TODO: Switch to our own Azure account, currently using ai.moda's account
    os.environ["AZURE_TENANT_ID"] = "17c0a27a-1246-4aa1-a3b6-d294e80e783c"
    os.environ["AZURE_CLIENT_ID"] = "4faf5422-b2bd-45e8-a6d7-46543a38acd0"
    azure_ad_token = get_azure_ad_token_from_oidc(
        azure_ad_token="oidc/circleci/",
        azure_client_id=None,
        azure_tenant_id=None,
    )

    print(f"secret_val: {redact_oidc_signature(azure_ad_token)}")


@pytest.mark.skipif(
    os.environ.get("CIRCLE_OIDC_TOKEN") is None,
    reason="Cannot run without being in CircleCI Runner",
)
def test_oidc_circle_v1_with_amazon():
    # The purpose of this test is to get logs using the older v1 of the CircleCI OIDC token

    # TODO: This is using ai.moda's IAM role, we should use LiteLLM's IAM role eventually
    aws_role_name = "arn:aws:iam::335785316107:role/litellm-github-unit-tests-circleci-v1-assume-only"
    aws_web_identity_token = "oidc/circleci/"

    bllm = BedrockLLM()
    creds = bllm.get_credentials(
        aws_region_name="ca-west-1",
        aws_web_identity_token=aws_web_identity_token,
        aws_role_name=aws_role_name,
        aws_session_name="assume-v1-session",
    )


@pytest.mark.skipif(
    os.environ.get("CIRCLE_OIDC_TOKEN") is None,
    reason="Cannot run without being in CircleCI Runner",
)
def test_oidc_circle_v1_with_amazon_fips():
    # The purpose of this test is to validate that we can assume a role in a FIPS region

    # TODO: This is using ai.moda's IAM role, we should use LiteLLM's IAM role eventually
    aws_role_name = "arn:aws:iam::335785316107:role/litellm-github-unit-tests-circleci-v1-assume-only"
    aws_web_identity_token = "oidc/circleci/"

    bllm = BedrockConverseLLM()
    creds = bllm.get_credentials(
        aws_region_name="us-west-1",
        aws_web_identity_token=aws_web_identity_token,
        aws_role_name=aws_role_name,
        aws_session_name="assume-v1-session-fips",
        aws_sts_endpoint="https://sts-fips.us-west-1.amazonaws.com",
    )


def test_oidc_env_variable():
    # Create a unique environment variable name
    env_var_name = "OIDC_TEST_PATH_" + uuid4().hex
    os.environ[env_var_name] = "secret-" + uuid4().hex
    secret_val = get_secret(f"oidc/env/{env_var_name}")

    print(f"secret_val: {redact_oidc_signature(secret_val)}")

    assert secret_val == os.environ[env_var_name]

    # now unset the environment variable
    del os.environ[env_var_name]


def test_oidc_file():
    # Create a temporary file
    with tempfile.NamedTemporaryFile(mode="w+") as temp_file:
        secret_value = "secret-" + uuid4().hex
        temp_file.write(secret_value)
        temp_file.flush()
        temp_file_path = temp_file.name

        secret_val = get_secret(f"oidc/file/{temp_file_path}")

        print(f"secret_val: {redact_oidc_signature(secret_val)}")

        assert secret_val == secret_value


def test_oidc_env_path():
    # Create a temporary file
    with tempfile.NamedTemporaryFile(mode="w+") as temp_file:
        secret_value = "secret-" + uuid4().hex
        temp_file.write(secret_value)
        temp_file.flush()
        temp_file_path = temp_file.name

        # Create a unique environment variable name
        env_var_name = "OIDC_TEST_PATH_" + uuid4().hex

        # Set the environment variable to the temporary file path
        os.environ[env_var_name] = temp_file_path

        # Test getting the secret using the environment variable
        secret_val = get_secret(f"oidc/env_path/{env_var_name}")

        print(f"secret_val: {redact_oidc_signature(secret_val)}")

        assert secret_val == secret_value

        del os.environ[env_var_name]


@pytest.mark.flaky(retries=6, delay=1)
def test_google_secret_manager():
    """
    Test that we can get a secret from Google Secret Manager
    """
    os.environ["GOOGLE_SECRET_MANAGER_PROJECT_ID"] = "pathrise-convert-1606954137718"

    from litellm.secret_managers.google_secret_manager import GoogleSecretManager

    load_vertex_ai_credentials()
    secret_manager = GoogleSecretManager()

    secret_val = secret_manager.get_secret_from_google_secret_manager(
        secret_name="OPENAI_API_KEY"
    )
    print("secret_val: {}".format(secret_val))

    assert (
        secret_val == "anything"
    ), "did not get expected secret value. expect 'anything', got '{}'".format(
        secret_val
    )


def test_google_secret_manager_read_in_memory():
    """
    Test that Google Secret Manager returns the in-memory cached value when it exists
    """
    from litellm.secret_managers.google_secret_manager import GoogleSecretManager

    load_vertex_ai_credentials()
    os.environ["GOOGLE_SECRET_MANAGER_PROJECT_ID"] = "pathrise-convert-1606954137718"
    secret_manager = GoogleSecretManager()
    secret_manager.cache.cache_dict["UNIQUE_KEY"] = None
    secret_manager.cache.cache_dict["UNIQUE_KEY_2"] = "lite-llm"

    secret_val = secret_manager.get_secret_from_google_secret_manager(
        secret_name="UNIQUE_KEY"
    )
    print("secret_val: {}".format(secret_val))
    assert secret_val is None

    secret_val = secret_manager.get_secret_from_google_secret_manager(
        secret_name="UNIQUE_KEY_2"
    )
    print("secret_val: {}".format(secret_val))
    assert secret_val == "lite-llm"


def test_should_read_secret_from_secret_manager():
    """
    Test that _should_read_secret_from_secret_manager returns correct values based on access mode
    """
    from litellm.types.secret_managers.main import KeyManagementSettings

    # Test when secret manager client is None
    litellm.secret_manager_client = None
    litellm._key_management_settings = KeyManagementSettings()
    assert _should_read_secret_from_secret_manager() is False

    # Test with secret manager client and read_only access
    litellm.secret_manager_client = "dummy_client"
    litellm._key_management_settings = KeyManagementSettings(access_mode="read_only")
    assert _should_read_secret_from_secret_manager() is True

    # Test with secret manager client and read_and_write access
    litellm._key_management_settings = KeyManagementSettings(
        access_mode="read_and_write"
    )
    assert _should_read_secret_from_secret_manager() is True

    # Test with secret manager client and write_only access
    litellm._key_management_settings = KeyManagementSettings(access_mode="write_only")
    assert _should_read_secret_from_secret_manager() is False

    # Reset global variables
    litellm.secret_manager_client = None
    litellm._key_management_settings = KeyManagementSettings()


def test_get_secret_with_access_mode():
    """
    Test that get_secret respects access mode settings
    """
    from litellm.types.secret_managers.main import KeyManagementSettings

    # Set up test environment
    test_secret_name = "TEST_SECRET_KEY"
    test_secret_value = "test_secret_value"
    os.environ[test_secret_name] = test_secret_value

    # Test with write_only access (should read from os.environ)
    litellm.secret_manager_client = "dummy_client"
    litellm._key_management_settings = KeyManagementSettings(access_mode="write_only")
    assert get_secret(test_secret_name) == test_secret_value

    # Test with no KeyManagementSettings but secret_manager_client set
    litellm.secret_manager_client = "dummy_client"
    litellm._key_management_settings = KeyManagementSettings()
    assert _should_read_secret_from_secret_manager() is True

    # Test with read_only access
    litellm._key_management_settings = KeyManagementSettings(access_mode="read_only")
    assert _should_read_secret_from_secret_manager() is True

    # Test with read_and_write access
    litellm._key_management_settings = KeyManagementSettings(
        access_mode="read_and_write"
    )
    assert _should_read_secret_from_secret_manager() is True

    # Reset global variables
    litellm.secret_manager_client = None
    litellm._key_management_settings = KeyManagementSettings()
    del os.environ[test_secret_name]
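
Read together, the access-mode tests above pin down a small decision rule. A sketch of it (illustrative; the real helper is litellm.secret_managers.main._should_read_secret_from_secret_manager, and the assumption that the default access mode permits reading is inferred from the "no KeyManagementSettings" case above):

# Sketch of the behaviour asserted above: reads hit the secret manager
# only when a client is configured and the access mode permits reading;
# "write_only" falls back to os.environ.
def should_read_from_secret_manager_sketch(client, access_mode: str = "read_only") -> bool:
    if client is None:
        return False
    return access_mode in ("read_only", "read_and_write")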
2326
Development/litellm/tests/litellm_utils_tests/test_utils.py
Normal file
File diff suppressed because it is too large
@@ -0,0 +1,63 @@
import pytest
import sys
import os

sys.path.insert(0, os.path.abspath("../.."))

from litellm.utils import validate_chat_completion_tool_choice


def test_validate_tool_choice_none():
    """Test that None is returned as-is."""
    result = validate_chat_completion_tool_choice(None)
    assert result is None


def test_validate_tool_choice_string():
    """Test that string values are returned as-is."""
    assert validate_chat_completion_tool_choice("auto") == "auto"
    assert validate_chat_completion_tool_choice("none") == "none"
    assert validate_chat_completion_tool_choice("required") == "required"


def test_validate_tool_choice_standard_dict():
    """Test standard OpenAI format with function."""
    tool_choice = {"type": "function", "function": {"name": "my_function"}}
    result = validate_chat_completion_tool_choice(tool_choice)
    assert result == tool_choice


def test_validate_tool_choice_cursor_format():
    """Test Cursor IDE format: {"type": "auto"} -> {"type": "auto"}."""
    assert validate_chat_completion_tool_choice({"type": "auto"}) == {"type": "auto"}
    assert validate_chat_completion_tool_choice({"type": "none"}) == {"type": "none"}
    assert validate_chat_completion_tool_choice({"type": "required"}) == {"type": "required"}


def test_validate_tool_choice_invalid_dict():
    """Test that invalid dict formats raise exceptions."""
    # Missing both type and function
    with pytest.raises(Exception) as exc_info:
        validate_chat_completion_tool_choice({})
    assert "Invalid tool choice" in str(exc_info.value)

    # Invalid type value
    with pytest.raises(Exception) as exc_info:
        validate_chat_completion_tool_choice({"type": "invalid"})
    assert "Invalid tool choice" in str(exc_info.value)

    # Has type but missing function when type is "function"
    with pytest.raises(Exception) as exc_info:
        validate_chat_completion_tool_choice({"type": "function"})
    assert "Invalid tool choice" in str(exc_info.value)


def test_validate_tool_choice_invalid_type():
    """Test that invalid types raise exceptions."""
    with pytest.raises(Exception) as exc_info:
        validate_chat_completion_tool_choice(123)
    assert "Got=<class 'int'>" in str(exc_info.value)

    with pytest.raises(Exception) as exc_info:
        validate_chat_completion_tool_choice([])
    assert "Got=<class 'list'>" in str(exc_info.value)
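
Taken together, the assertions above imply validation rules along these lines (a sketch inferred from the tests, not the actual body of validate_chat_completion_tool_choice):

# Sketch of the tool_choice validation rules the tests above imply.
def validate_tool_choice_sketch(tool_choice):
    if tool_choice is None or isinstance(tool_choice, str):
        return tool_choice  # e.g. "auto", "none", "required"
    if isinstance(tool_choice, dict):
        choice_type = tool_choice.get("type")
        if choice_type == "function" and "function" in tool_choice:
            return tool_choice  # standard OpenAI format
        if choice_type in ("auto", "none", "required"):
            return tool_choice  # Cursor-style {"type": "auto"} passes through
        raise ValueError(f"Invalid tool choice: {tool_choice}")
    raise ValueError(f"Expected str, dict, or None. Got={type(tool_choice)}")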
@@ -0,0 +1,13 @@
{
  "type": "service_account",
  "project_id": "pathrise-convert-1606954137718",
  "private_key_id": "",
  "private_key": "",
  "client_email": "ci-cd-723@pathrise-convert-1606954137718.iam.gserviceaccount.com",
  "client_id": "109577393201924326488",
  "auth_uri": "https://accounts.google.com/o/oauth2/auth",
  "token_uri": "https://oauth2.googleapis.com/token",
  "auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs",
  "client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/ci-cd-723%40pathrise-convert-1606954137718.iam.gserviceaccount.com",
  "universe_domain": "googleapis.com"
}