Added LiteLLM to the stack
This commit is contained in:
@@ -0,0 +1,216 @@
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from typing import Optional
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
# Adds the grandparent directory to sys.path to allow importing project modules
|
||||
sys.path.insert(0, os.path.abspath("../.."))
|
||||
|
||||
import pytest
|
||||
|
||||
from litellm.secret_managers.get_azure_ad_token_provider import (
|
||||
get_azure_ad_token_provider,
|
||||
)
|
||||
|
||||
|
||||
class TestGetAzureAdTokenProvider:
|
||||
@patch.dict(
|
||||
os.environ,
|
||||
{
|
||||
"AZURE_CLIENT_ID": "test-client-id",
|
||||
"AZURE_CLIENT_SECRET": "test-client-secret",
|
||||
"AZURE_TENANT_ID": "test-tenant-id",
|
||||
"AZURE_SCOPE": "https://cognitiveservices.azure.com/.default",
|
||||
"AZURE_CREDENTIAL": "ClientSecretCredential",
|
||||
},
|
||||
)
|
||||
@patch("azure.identity.get_bearer_token_provider")
|
||||
@patch("azure.identity.ClientSecretCredential")
|
||||
def test_get_azure_ad_token_provider_client_secret_credential(
|
||||
self, mock_client_secret_credential, mock_get_bearer_token_provider
|
||||
):
|
||||
"""Test get_azure_ad_token_provider with ClientSecretCredential."""
|
||||
# Mock the Azure identity credential instance
|
||||
mock_credential_instance = MagicMock()
|
||||
mock_client_secret_credential.return_value = mock_credential_instance
|
||||
|
||||
# Mock the bearer token provider
|
||||
mock_token_provider = MagicMock(return_value="mock-token")
|
||||
mock_get_bearer_token_provider.return_value = mock_token_provider
|
||||
|
||||
# Call the function
|
||||
result = get_azure_ad_token_provider()
|
||||
|
||||
# Assertions
|
||||
assert callable(result)
|
||||
mock_client_secret_credential.assert_called_once_with(
|
||||
client_id="test-client-id",
|
||||
client_secret="test-client-secret",
|
||||
tenant_id="test-tenant-id",
|
||||
)
|
||||
mock_get_bearer_token_provider.assert_called_once_with(
|
||||
mock_credential_instance, "https://cognitiveservices.azure.com/.default"
|
||||
)
|
||||
|
||||
# Test that the returned callable works
|
||||
token = result()
|
||||
assert token == "mock-token"
|
||||
|
||||
@patch.dict(
|
||||
os.environ,
|
||||
{
|
||||
"AZURE_CLIENT_ID": "test-client-id",
|
||||
"AZURE_SCOPE": "https://cognitiveservices.azure.com/.default",
|
||||
"AZURE_CREDENTIAL": "ManagedIdentityCredential",
|
||||
},
|
||||
)
|
||||
@patch("azure.identity.get_bearer_token_provider")
|
||||
@patch("azure.identity.ManagedIdentityCredential")
|
||||
def test_get_azure_ad_token_provider_managed_identity_credential(
|
||||
self, mock_managed_identity_credential, mock_get_bearer_token_provider
|
||||
):
|
||||
"""Test get_azure_ad_token_provider with ManagedIdentityCredential."""
|
||||
# Mock the Azure identity credential instance
|
||||
mock_credential_instance = MagicMock()
|
||||
mock_managed_identity_credential.return_value = mock_credential_instance
|
||||
|
||||
# Mock the bearer token provider
|
||||
mock_token_provider = MagicMock(return_value="mock-managed-identity-token")
|
||||
mock_get_bearer_token_provider.return_value = mock_token_provider
|
||||
|
||||
# Call the function
|
||||
result = get_azure_ad_token_provider()
|
||||
|
||||
# Assertions
|
||||
assert callable(result)
|
||||
mock_managed_identity_credential.assert_called_once_with(
|
||||
client_id="test-client-id"
|
||||
)
|
||||
mock_get_bearer_token_provider.assert_called_once_with(
|
||||
mock_credential_instance, "https://cognitiveservices.azure.com/.default"
|
||||
)
|
||||
|
||||
# Test that the returned callable works
|
||||
token = result()
|
||||
assert token == "mock-managed-identity-token"
|
||||
|
||||
@patch.dict(
|
||||
os.environ,
|
||||
{
|
||||
"AZURE_CLIENT_ID": "test-client-id",
|
||||
"AZURE_TENANT_ID": "test-tenant-id",
|
||||
"AZURE_CERTIFICATE_PATH": "/path/to/cert.pem",
|
||||
"AZURE_SCOPE": "https://cognitiveservices.azure.com/.default",
|
||||
"AZURE_CREDENTIAL": "CertificateCredential",
|
||||
},
|
||||
)
|
||||
@patch("azure.identity.get_bearer_token_provider")
|
||||
@patch("azure.identity.CertificateCredential")
|
||||
def test_get_azure_ad_token_provider_certificate_credential(
|
||||
self, mock_certificate_credential, mock_get_bearer_token_provider
|
||||
):
|
||||
"""Test get_azure_ad_token_provider with CertificateCredential."""
|
||||
# Mock the Azure identity credential instance
|
||||
mock_credential_instance = MagicMock()
|
||||
mock_certificate_credential.return_value = mock_credential_instance
|
||||
|
||||
# Mock the bearer token provider
|
||||
mock_token_provider = MagicMock(return_value="mock-certificate-token")
|
||||
mock_get_bearer_token_provider.return_value = mock_token_provider
|
||||
|
||||
# Call the function
|
||||
result = get_azure_ad_token_provider()
|
||||
|
||||
# Assertions
|
||||
assert callable(result)
|
||||
mock_certificate_credential.assert_called_once_with(
|
||||
client_id="test-client-id",
|
||||
tenant_id="test-tenant-id",
|
||||
certificate_path="/path/to/cert.pem",
|
||||
)
|
||||
mock_get_bearer_token_provider.assert_called_once_with(
|
||||
mock_credential_instance, "https://cognitiveservices.azure.com/.default"
|
||||
)
|
||||
|
||||
# Test that the returned callable works
|
||||
token = result()
|
||||
assert token == "mock-certificate-token"
|
||||
|
||||
@patch.dict(
|
||||
os.environ,
|
||||
{
|
||||
"AZURE_CLIENT_ID": "test-client-id",
|
||||
"AZURE_TENANT_ID": "test-tenant-id",
|
||||
"AZURE_CERTIFICATE_PATH": "/path/to/cert.pem",
|
||||
"AZURE_SCOPE": "https://cognitiveservices.azure.com/.default",
|
||||
"AZURE_CREDENTIAL": "CertificateCredential",
|
||||
"AZURE_CERTIFICATE_PASSWORD": "pwd4cert.pem",
|
||||
},
|
||||
)
|
||||
@patch("azure.identity.get_bearer_token_provider")
|
||||
@patch("azure.identity.CertificateCredential")
|
||||
def test_get_azure_ad_token_provider_password_protected_certificate_credential(
|
||||
self, mock_certificate_credential, mock_get_bearer_token_provider
|
||||
):
|
||||
"""Test get_azure_ad_token_provider with password protected certificate in CertificateCredential."""
|
||||
# Mock the Azure identity credential instance
|
||||
mock_credential_instance = MagicMock()
|
||||
mock_certificate_credential.return_value = mock_credential_instance
|
||||
|
||||
# Mock the bearer token provider
|
||||
mock_token_provider = MagicMock(return_value="mock-certificate-token")
|
||||
mock_get_bearer_token_provider.return_value = mock_token_provider
|
||||
|
||||
# Call the function
|
||||
result = get_azure_ad_token_provider()
|
||||
|
||||
# Assertions
|
||||
assert callable(result)
|
||||
mock_certificate_credential.assert_called_once_with(
|
||||
client_id="test-client-id",
|
||||
tenant_id="test-tenant-id",
|
||||
certificate_path="/path/to/cert.pem",
|
||||
password="pwd4cert.pem",
|
||||
)
|
||||
mock_get_bearer_token_provider.assert_called_once_with(
|
||||
mock_credential_instance, "https://cognitiveservices.azure.com/.default"
|
||||
)
|
||||
|
||||
# Test that the returned callable works
|
||||
token = result()
|
||||
assert token == "mock-certificate-token"
|
||||
|
||||
@patch.dict(
|
||||
os.environ,
|
||||
{
|
||||
"AZURE_CREDENTIAL": "DefaultAzureCredential",
|
||||
},
|
||||
)
|
||||
@patch("azure.identity.get_bearer_token_provider")
|
||||
@patch("azure.identity.DefaultAzureCredential")
|
||||
def test_get_azure_ad_token_provider_default_azure_credential(
|
||||
self, mock_certificate_credential, mock_get_bearer_token_provider
|
||||
):
|
||||
"""Test get_azure_ad_token_provider with DefaultAzureCredential."""
|
||||
# Mock the Azure identity credential instance
|
||||
mock_credential_instance = MagicMock()
|
||||
mock_certificate_credential.return_value = mock_credential_instance
|
||||
|
||||
# Mock the bearer token provider
|
||||
mock_token_provider = MagicMock(return_value="mock-certificate-token")
|
||||
mock_get_bearer_token_provider.return_value = mock_token_provider
|
||||
|
||||
# Call the function
|
||||
result = get_azure_ad_token_provider()
|
||||
|
||||
# Assertions
|
||||
assert callable(result)
|
||||
mock_certificate_credential.assert_called_once_with()
|
||||
mock_get_bearer_token_provider.assert_called_once_with(
|
||||
mock_credential_instance, "https://cognitiveservices.azure.com/.default"
|
||||
)
|
||||
|
||||
# Test that the returned callable works
|
||||
token = result()
|
||||
assert token == "mock-certificate-token"
|
@@ -0,0 +1,197 @@
|
||||
import logging
|
||||
import os
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from litellm.secret_managers.main import get_secret
|
||||
|
||||
# Set up logging for debugging
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Mock HTTPHandler and oidc_cache
|
||||
class MockHTTPHandler:
|
||||
def __init__(self, timeout):
|
||||
self.timeout = timeout
|
||||
self.status_code = 200
|
||||
self.text = "mocked_token"
|
||||
self.json_data = {"value": "mocked_token"}
|
||||
|
||||
def get(self, url, params=None, headers=None):
|
||||
# Store params for audience verification
|
||||
self.last_params = params
|
||||
logger.debug(
|
||||
f"MockHTTPHandler.get called with url={url}, params={params}, headers={headers}"
|
||||
)
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = self.status_code
|
||||
mock_response.text = self.text
|
||||
mock_response.json.return_value = self.json_data
|
||||
return mock_response
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_oidc_cache():
|
||||
cache = Mock()
|
||||
cache.get_cache.return_value = None
|
||||
cache.set_cache = Mock()
|
||||
return cache
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_env():
|
||||
with patch.dict(os.environ, {}, clear=True):
|
||||
yield os.environ
|
||||
|
||||
|
||||
@patch("litellm.secret_managers.main.oidc_cache")
|
||||
@patch("litellm.secret_managers.main.HTTPHandler")
|
||||
def test_oidc_google_success(mock_http_handler, mock_oidc_cache):
|
||||
mock_oidc_cache.get_cache.return_value = None
|
||||
mock_handler = MockHTTPHandler(timeout=600.0)
|
||||
mock_http_handler.return_value = mock_handler
|
||||
secret_name = "oidc/google/[invalid url, do not cite]"
|
||||
result = get_secret(secret_name)
|
||||
|
||||
assert result == "mocked_token"
|
||||
assert mock_handler.last_params == {"audience": "[invalid url, do not cite]"}
|
||||
mock_oidc_cache.set_cache.assert_called_once_with(
|
||||
key=secret_name, value="mocked_token", ttl=3540
|
||||
)
|
||||
|
||||
|
||||
@patch("litellm.secret_managers.main.oidc_cache")
|
||||
def test_oidc_google_cached(mock_oidc_cache):
|
||||
mock_oidc_cache.get_cache.return_value = "cached_token"
|
||||
|
||||
secret_name = "oidc/google/[invalid url, do not cite]"
|
||||
with patch("litellm.HTTPHandler") as mock_http:
|
||||
result = get_secret(secret_name)
|
||||
|
||||
assert result == "cached_token", f"Expected cached token, got {result}"
|
||||
mock_oidc_cache.get_cache.assert_called_with(key=secret_name)
|
||||
mock_http.assert_not_called()
|
||||
|
||||
|
||||
def test_oidc_google_failure(mock_oidc_cache):
|
||||
mock_handler = MockHTTPHandler(timeout=600.0)
|
||||
mock_handler.status_code = 400
|
||||
|
||||
with patch("litellm.secret_managers.main.HTTPHandler", return_value=mock_handler):
|
||||
mock_oidc_cache.get_cache.return_value = None
|
||||
secret_name = "oidc/google/https://example.com/api"
|
||||
|
||||
with pytest.raises(ValueError, match="Google OIDC provider failed"):
|
||||
get_secret(secret_name)
|
||||
|
||||
|
||||
def test_oidc_circleci_success(monkeypatch):
|
||||
monkeypatch.setenv("CIRCLE_OIDC_TOKEN", "circleci_token")
|
||||
|
||||
secret_name = "oidc/circleci/test-audience"
|
||||
result = get_secret(secret_name)
|
||||
|
||||
assert result == "circleci_token"
|
||||
|
||||
|
||||
def test_oidc_circleci_failure(monkeypatch):
|
||||
monkeypatch.delenv("CIRCLE_OIDC_TOKEN", raising=False)
|
||||
secret_name = "oidc/circleci/test-audience"
|
||||
|
||||
with pytest.raises(ValueError, match="CIRCLE_OIDC_TOKEN not found in environment"):
|
||||
get_secret(secret_name)
|
||||
|
||||
|
||||
@patch("litellm.secret_managers.main.oidc_cache")
|
||||
@patch("litellm.secret_managers.main.HTTPHandler")
|
||||
def test_oidc_github_success(mock_http_handler, mock_oidc_cache, mock_env):
|
||||
mock_env["ACTIONS_ID_TOKEN_REQUEST_URL"] = "https://github.com/token"
|
||||
mock_env["ACTIONS_ID_TOKEN_REQUEST_TOKEN"] = "github_token"
|
||||
mock_oidc_cache.get_cache.return_value = None
|
||||
mock_handler = MockHTTPHandler(timeout=600.0)
|
||||
mock_http_handler.return_value = mock_handler
|
||||
|
||||
secret_name = "oidc/github/github-audience"
|
||||
result = get_secret(secret_name)
|
||||
|
||||
assert result == "mocked_token", f"Expected token 'mocked_token', got {result}"
|
||||
assert mock_handler.last_params == {"audience": "github-audience"}
|
||||
logger.debug(f"set_cache call args: {mock_oidc_cache.set_cache.call_args}")
|
||||
mock_oidc_cache.set_cache.assert_called_once()
|
||||
mock_oidc_cache.set_cache.assert_called_with(
|
||||
key=secret_name, value="mocked_token", ttl=295
|
||||
)
|
||||
|
||||
|
||||
def test_oidc_github_missing_env():
|
||||
secret_name = "oidc/github/github-audience"
|
||||
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="ACTIONS_ID_TOKEN_REQUEST_URL or ACTIONS_ID_TOKEN_REQUEST_TOKEN not found in environment",
|
||||
):
|
||||
get_secret(secret_name)
|
||||
|
||||
|
||||
def test_oidc_azure_file_success(mock_env, tmp_path):
|
||||
token_file = tmp_path / "token.txt"
|
||||
token_file.write_text("azure_token")
|
||||
mock_env["AZURE_FEDERATED_TOKEN_FILE"] = str(token_file)
|
||||
|
||||
secret_name = "oidc/azure/azure-audience"
|
||||
result = get_secret(secret_name)
|
||||
|
||||
assert result == "azure_token"
|
||||
|
||||
|
||||
@patch("litellm.secret_managers.main.get_azure_ad_token_provider")
|
||||
def test_oidc_azure_ad_token_success(mock_get_azure_ad_token_provider):
|
||||
mock_token_provider = Mock(return_value="azure_ad_token")
|
||||
mock_get_azure_ad_token_provider.return_value = mock_token_provider
|
||||
secret_name = "oidc/azure/api://azure-audience"
|
||||
result = get_secret(secret_name)
|
||||
|
||||
assert result == "azure_ad_token"
|
||||
mock_get_azure_ad_token_provider.assert_called_once_with(
|
||||
azure_scope="api://azure-audience"
|
||||
)
|
||||
mock_token_provider.assert_called_once_with()
|
||||
|
||||
|
||||
def test_oidc_file_success(tmp_path):
|
||||
token_file = tmp_path / "token.txt"
|
||||
token_file.write_text("file_token")
|
||||
|
||||
secret_name = f"oidc/file/{token_file}"
|
||||
result = get_secret(secret_name)
|
||||
|
||||
assert result == "file_token"
|
||||
|
||||
|
||||
def test_oidc_env_success(mock_env):
|
||||
mock_env["CUSTOM_TOKEN"] = "env_token"
|
||||
|
||||
secret_name = "oidc/env/CUSTOM_TOKEN"
|
||||
result = get_secret(secret_name)
|
||||
|
||||
assert result == "env_token"
|
||||
|
||||
|
||||
def test_oidc_env_path_success(mock_env, tmp_path):
|
||||
token_file = tmp_path / "token.txt"
|
||||
token_file.write_text("env_path_token")
|
||||
mock_env["TOKEN_PATH"] = str(token_file)
|
||||
|
||||
secret_name = "oidc/env_path/TOKEN_PATH"
|
||||
result = get_secret(secret_name)
|
||||
|
||||
assert result == "env_path_token"
|
||||
|
||||
|
||||
def test_unsupported_oidc_provider():
|
||||
secret_name = "oidc/unsupported/unsupported-audience"
|
||||
|
||||
with pytest.raises(ValueError, match="Unsupported OIDC provider"):
|
||||
get_secret(secret_name)
|
Reference in New Issue
Block a user