Added LiteLLM to the stack
This commit is contained in:
@@ -0,0 +1,191 @@
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
from unittest.mock import MagicMock, mock_open, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from litellm.llms.github_copilot.authenticator import Authenticator
|
||||
from litellm.llms.github_copilot.common_utils import (
|
||||
APIKeyExpiredError,
|
||||
GetAccessTokenError,
|
||||
GetAPIKeyError,
|
||||
GetDeviceCodeError,
|
||||
RefreshAPIKeyError,
|
||||
)
|
||||
|
||||
|
||||
class TestGitHubCopilotAuthenticator:
|
||||
@pytest.fixture
|
||||
def authenticator(self):
|
||||
with patch("os.path.exists", return_value=False), patch("os.makedirs") as mock_makedirs:
|
||||
auth = Authenticator()
|
||||
mock_makedirs.assert_called_once()
|
||||
return auth
|
||||
|
||||
@pytest.fixture
|
||||
def mock_http_client(self):
|
||||
mock_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_client.get.return_value = mock_response
|
||||
mock_client.post.return_value = mock_response
|
||||
mock_response.raise_for_status.return_value = None
|
||||
return mock_client, mock_response
|
||||
|
||||
def test_init(self):
|
||||
"""Test the initialization of the authenticator."""
|
||||
with patch("os.path.exists", return_value=False), patch("os.makedirs") as mock_makedirs:
|
||||
auth = Authenticator()
|
||||
assert auth.token_dir.endswith("/github_copilot")
|
||||
assert auth.access_token_file.endswith("/access-token")
|
||||
assert auth.api_key_file.endswith("/api-key.json")
|
||||
mock_makedirs.assert_called_once()
|
||||
|
||||
def test_ensure_token_dir(self):
|
||||
"""Test that the token directory is created if it doesn't exist."""
|
||||
with patch("os.path.exists", return_value=False), patch("os.makedirs") as mock_makedirs:
|
||||
auth = Authenticator()
|
||||
mock_makedirs.assert_called_once_with(auth.token_dir, exist_ok=True)
|
||||
|
||||
def test_get_github_headers(self, authenticator):
|
||||
"""Test that GitHub headers are correctly generated."""
|
||||
headers = authenticator._get_github_headers()
|
||||
assert "accept" in headers
|
||||
assert "editor-version" in headers
|
||||
assert "user-agent" in headers
|
||||
assert "content-type" in headers
|
||||
|
||||
headers_with_token = authenticator._get_github_headers("test-token")
|
||||
assert headers_with_token["authorization"] == "token test-token"
|
||||
|
||||
def test_get_access_token_from_file(self, authenticator):
|
||||
"""Test retrieving an access token from a file."""
|
||||
mock_token = "mock-access-token"
|
||||
|
||||
with patch("builtins.open", mock_open(read_data=mock_token)):
|
||||
token = authenticator.get_access_token()
|
||||
assert token == mock_token
|
||||
|
||||
def test_get_access_token_login(self, authenticator):
|
||||
"""Test logging in to get an access token."""
|
||||
mock_token = "mock-access-token"
|
||||
|
||||
with patch.object(authenticator, "_login", return_value=mock_token), \
|
||||
patch("builtins.open", mock_open()), \
|
||||
patch("builtins.open", side_effect=IOError) as mock_read:
|
||||
token = authenticator.get_access_token()
|
||||
assert token == mock_token
|
||||
authenticator._login.assert_called_once()
|
||||
|
||||
def test_get_access_token_failure(self, authenticator):
|
||||
"""Test that an exception is raised after multiple login failures."""
|
||||
with patch.object(authenticator, "_login", side_effect=GetDeviceCodeError(message="Test error", status_code=400)), \
|
||||
patch("builtins.open", side_effect=IOError):
|
||||
with pytest.raises(GetAccessTokenError):
|
||||
authenticator.get_access_token()
|
||||
assert authenticator._login.call_count == 3
|
||||
|
||||
def test_get_api_key_from_file(self, authenticator):
|
||||
"""Test retrieving an API key from a file."""
|
||||
future_time = (datetime.now() + timedelta(hours=1)).timestamp()
|
||||
mock_api_key_data = json.dumps({"token": "mock-api-key", "expires_at": future_time})
|
||||
|
||||
with patch("builtins.open", mock_open(read_data=mock_api_key_data)):
|
||||
api_key = authenticator.get_api_key()
|
||||
assert api_key == "mock-api-key"
|
||||
|
||||
def test_get_api_key_expired(self, authenticator):
|
||||
"""Test refreshing an expired API key."""
|
||||
past_time = (datetime.now() - timedelta(hours=1)).timestamp()
|
||||
mock_expired_data = json.dumps({"token": "expired-api-key", "expires_at": past_time})
|
||||
mock_new_data = {"token": "new-api-key", "expires_at": (datetime.now() + timedelta(hours=1)).timestamp()}
|
||||
|
||||
with patch("builtins.open", mock_open(read_data=mock_expired_data)), \
|
||||
patch.object(authenticator, "_refresh_api_key", return_value=mock_new_data), \
|
||||
patch("json.dump") as mock_json_dump:
|
||||
api_key = authenticator.get_api_key()
|
||||
assert api_key == "new-api-key"
|
||||
authenticator._refresh_api_key.assert_called_once()
|
||||
|
||||
def test_refresh_api_key(self, authenticator, mock_http_client):
|
||||
"""Test refreshing an API key."""
|
||||
mock_client, mock_response = mock_http_client
|
||||
mock_token = "mock-access-token"
|
||||
mock_api_key_data = {"token": "new-api-key", "expires_at": 12345}
|
||||
|
||||
with patch.object(authenticator, "get_access_token", return_value=mock_token), \
|
||||
patch("litellm.llms.github_copilot.authenticator._get_httpx_client", return_value=mock_client), \
|
||||
patch.object(mock_response, "json", return_value=mock_api_key_data):
|
||||
result = authenticator._refresh_api_key()
|
||||
assert result == mock_api_key_data
|
||||
mock_client.get.assert_called_once()
|
||||
authenticator.get_access_token.assert_called_once()
|
||||
|
||||
def test_refresh_api_key_failure(self, authenticator, mock_http_client):
|
||||
"""Test failure to refresh an API key."""
|
||||
mock_client, mock_response = mock_http_client
|
||||
mock_token = "mock-access-token"
|
||||
|
||||
with patch.object(authenticator, "get_access_token", return_value=mock_token), \
|
||||
patch("litellm.llms.github_copilot.authenticator._get_httpx_client", return_value=mock_client), \
|
||||
patch.object(mock_response, "json", return_value={}):
|
||||
with pytest.raises(RefreshAPIKeyError):
|
||||
authenticator._refresh_api_key()
|
||||
assert mock_client.get.call_count == 3
|
||||
|
||||
def test_get_device_code(self, authenticator, mock_http_client):
|
||||
"""Test getting a device code."""
|
||||
mock_client, mock_response = mock_http_client
|
||||
mock_device_code_data = {
|
||||
"device_code": "mock-device-code",
|
||||
"user_code": "ABCD-EFGH",
|
||||
"verification_uri": "https://github.com/login/device"
|
||||
}
|
||||
|
||||
with patch("litellm.llms.github_copilot.authenticator._get_httpx_client", return_value=mock_client), \
|
||||
patch.object(mock_response, "json", return_value=mock_device_code_data):
|
||||
result = authenticator._get_device_code()
|
||||
assert result == mock_device_code_data
|
||||
mock_client.post.assert_called_once()
|
||||
|
||||
def test_poll_for_access_token(self, authenticator, mock_http_client):
|
||||
"""Test polling for an access token."""
|
||||
mock_client, mock_response = mock_http_client
|
||||
mock_token_data = {"access_token": "mock-access-token"}
|
||||
|
||||
with patch("litellm.llms.github_copilot.authenticator._get_httpx_client", return_value=mock_client), \
|
||||
patch.object(mock_response, "json", return_value=mock_token_data), \
|
||||
patch("time.sleep"):
|
||||
result = authenticator._poll_for_access_token("mock-device-code")
|
||||
assert result == "mock-access-token"
|
||||
mock_client.post.assert_called_once()
|
||||
|
||||
def test_login(self, authenticator):
|
||||
"""Test the login process."""
|
||||
mock_device_code_data = {
|
||||
"device_code": "mock-device-code",
|
||||
"user_code": "ABCD-EFGH",
|
||||
"verification_uri": "https://github.com/login/device"
|
||||
}
|
||||
mock_token = "mock-access-token"
|
||||
|
||||
with patch.object(authenticator, "_get_device_code", return_value=mock_device_code_data), \
|
||||
patch.object(authenticator, "_poll_for_access_token", return_value=mock_token), \
|
||||
patch("builtins.print") as mock_print:
|
||||
result = authenticator._login()
|
||||
assert result == mock_token
|
||||
authenticator._get_device_code.assert_called_once()
|
||||
authenticator._poll_for_access_token.assert_called_once_with("mock-device-code")
|
||||
mock_print.assert_called_once()
|
||||
|
||||
def test_get_api_base_from_file(self, authenticator):
|
||||
"""Test retrieving the API base endpoint from a file."""
|
||||
mock_api_key_data = json.dumps({
|
||||
"token": "mock-api-key",
|
||||
"expires_at": (datetime.now() + timedelta(hours=1)).timestamp(),
|
||||
"endpoints": {"api": "https://api.enterprise.githubcopilot.com"}
|
||||
})
|
||||
with patch("builtins.open", mock_open(read_data=mock_api_key_data)):
|
||||
api_base = authenticator.get_api_base()
|
||||
assert api_base == "https://api.enterprise.githubcopilot.com"
|
@@ -0,0 +1,364 @@
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from datetime import datetime, timedelta
|
||||
from typing import AsyncGenerator
|
||||
from unittest.mock import AsyncMock, MagicMock, mock_open, patch
|
||||
|
||||
import pytest
|
||||
|
||||
sys.path.insert(0, os.path.abspath("../.."))
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from respx import MockRouter
|
||||
|
||||
import litellm
|
||||
|
||||
# Import at the top to make the patch work correctly
|
||||
import litellm.llms.github_copilot.chat.transformation
|
||||
from litellm import Choices, Message, ModelResponse, Usage, acompletion, completion
|
||||
from litellm.exceptions import AuthenticationError
|
||||
from litellm.llms.github_copilot.authenticator import Authenticator
|
||||
from litellm.llms.github_copilot.chat.transformation import GithubCopilotConfig
|
||||
from litellm.llms.github_copilot.common_utils import (
|
||||
APIKeyExpiredError,
|
||||
GetAccessTokenError,
|
||||
GetAPIKeyError,
|
||||
GetDeviceCodeError,
|
||||
RefreshAPIKeyError,
|
||||
)
|
||||
|
||||
|
||||
def test_github_copilot_config_get_openai_compatible_provider_info():
|
||||
"""Test the GitHub Copilot configuration provider info retrieval."""
|
||||
|
||||
config = GithubCopilotConfig()
|
||||
|
||||
# Mock the authenticator to avoid actual API calls
|
||||
mock_api_key = "gh.test-key-123456789"
|
||||
config.authenticator = MagicMock()
|
||||
config.authenticator.get_api_key.return_value = mock_api_key
|
||||
# Test with dynamic endpoint
|
||||
config.authenticator.get_api_base.return_value = "https://api.enterprise.githubcopilot.com"
|
||||
|
||||
# Test with default values
|
||||
model = "github_copilot/gpt-4"
|
||||
(
|
||||
api_base,
|
||||
dynamic_api_key,
|
||||
custom_llm_provider,
|
||||
) = config._get_openai_compatible_provider_info(
|
||||
model=model,
|
||||
api_base=None,
|
||||
api_key=None,
|
||||
custom_llm_provider="github_copilot",
|
||||
)
|
||||
|
||||
assert api_base == "https://api.enterprise.githubcopilot.com"
|
||||
assert dynamic_api_key == mock_api_key
|
||||
assert custom_llm_provider == "github_copilot"
|
||||
|
||||
# Test fallback to default if no dynamic endpoint
|
||||
config.authenticator.get_api_base.return_value = None
|
||||
(
|
||||
api_base,
|
||||
dynamic_api_key,
|
||||
custom_llm_provider,
|
||||
) = config._get_openai_compatible_provider_info(
|
||||
model=model,
|
||||
api_base=None,
|
||||
api_key=None,
|
||||
custom_llm_provider="github_copilot",
|
||||
)
|
||||
assert api_base == "https://api.githubcopilot.com/"
|
||||
|
||||
# Test with authentication failure
|
||||
config.authenticator.get_api_key.side_effect = GetAPIKeyError(
|
||||
message="Failed to get API key",
|
||||
status_code=401,
|
||||
)
|
||||
|
||||
with pytest.raises(AuthenticationError) as excinfo:
|
||||
config._get_openai_compatible_provider_info(
|
||||
model=model,
|
||||
api_base=None,
|
||||
api_key=None,
|
||||
custom_llm_provider="github_copilot",
|
||||
)
|
||||
|
||||
assert "Failed to get API key" in str(excinfo.value)
|
||||
|
||||
|
||||
@patch("litellm.llms.github_copilot.authenticator.Authenticator.get_api_key")
|
||||
@patch("litellm.llms.openai.openai.OpenAIChatCompletion.completion")
|
||||
def test_completion_github_copilot_mock_response(mock_completion, mock_get_api_key):
|
||||
"""Test the completion function with GitHub Copilot provider."""
|
||||
|
||||
# Mock the API key return value
|
||||
mock_api_key = "gh.test-key-123456789"
|
||||
mock_get_api_key.return_value = mock_api_key
|
||||
|
||||
# Mock completion response
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = "Hello, I'm GitHub Copilot!"
|
||||
mock_completion.return_value = mock_response
|
||||
|
||||
# Test non-streaming completion
|
||||
messages = [
|
||||
{"role": "system", "content": "You're GitHub Copilot, an AI assistant."},
|
||||
{"role": "user", "content": "Hello, who are you?"},
|
||||
]
|
||||
|
||||
# Create a properly formatted headers dictionary
|
||||
headers = {
|
||||
"editor-version": "Neovim/0.9.0",
|
||||
"Copilot-Integration-Id": "vscode-chat",
|
||||
}
|
||||
|
||||
response = completion(
|
||||
model="github_copilot/gpt-4",
|
||||
messages=messages,
|
||||
extra_headers=headers,
|
||||
)
|
||||
|
||||
assert response is not None
|
||||
|
||||
# Verify the get_api_key call was made (can be called multiple times)
|
||||
assert mock_get_api_key.call_count >= 1
|
||||
|
||||
# Verify the completion call was made with the expected params
|
||||
mock_completion.assert_called_once()
|
||||
args, kwargs = mock_completion.call_args
|
||||
|
||||
# Check that the proper authorization header is set
|
||||
assert "headers" in kwargs
|
||||
# Check that the model name is correctly formatted
|
||||
assert (
|
||||
kwargs.get("model") == "gpt-4"
|
||||
) # Model name should be without provider prefix
|
||||
assert kwargs.get("messages") == messages
|
||||
|
||||
|
||||
def test_transform_messages_disable_copilot_system_to_assistant(monkeypatch):
|
||||
"""Test that system messages are converted to assistant unless disable_copilot_system_to_assistant is True."""
|
||||
import litellm
|
||||
from litellm.llms.github_copilot.chat.transformation import GithubCopilotConfig
|
||||
|
||||
# Save original value
|
||||
original_flag = litellm.disable_copilot_system_to_assistant
|
||||
try:
|
||||
# Case 1: Flag is False (default, conversion happens)
|
||||
litellm.disable_copilot_system_to_assistant = False
|
||||
config = GithubCopilotConfig()
|
||||
messages = [
|
||||
{"role": "system", "content": "System message."},
|
||||
{"role": "user", "content": "User message."},
|
||||
]
|
||||
out = config._transform_messages([m.copy() for m in messages], model="github_copilot/gpt-4")
|
||||
assert out[0]["role"] == "assistant"
|
||||
assert out[1]["role"] == "user"
|
||||
|
||||
# Case 2: Flag is True (conversion does not happen)
|
||||
litellm.disable_copilot_system_to_assistant = True
|
||||
out = config._transform_messages([m.copy() for m in messages], model="github_copilot/gpt-4")
|
||||
assert out[0]["role"] == "system"
|
||||
assert out[1]["role"] == "user"
|
||||
|
||||
# Case 3: Flag is False again (conversion happens)
|
||||
litellm.disable_copilot_system_to_assistant = False
|
||||
out = config._transform_messages([m.copy() for m in messages], model="github_copilot/gpt-4")
|
||||
assert out[0]["role"] == "assistant"
|
||||
assert out[1]["role"] == "user"
|
||||
finally:
|
||||
# Restore original value
|
||||
litellm.disable_copilot_system_to_assistant = original_flag
|
||||
|
||||
|
||||
def test_x_initiator_header_user_request():
|
||||
"""Test that user-only messages result in X-Initiator: user header"""
|
||||
config = GithubCopilotConfig()
|
||||
|
||||
# Mock the authenticator
|
||||
config.authenticator = MagicMock()
|
||||
config.authenticator.get_api_key.return_value = "gh.test-key-123"
|
||||
config.authenticator.get_api_base.return_value = None
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": "You are an assistant."},
|
||||
{"role": "user", "content": "Hello!"},
|
||||
]
|
||||
|
||||
headers = config.validate_environment(
|
||||
headers={},
|
||||
model="github_copilot/gpt-4",
|
||||
messages=messages,
|
||||
optional_params={},
|
||||
litellm_params={},
|
||||
api_key=None,
|
||||
api_base=None,
|
||||
)
|
||||
|
||||
assert headers["X-Initiator"] == "user"
|
||||
|
||||
|
||||
def test_x_initiator_header_agent_request_with_assistant():
|
||||
"""Test that messages with assistant role result in X-Initiator: agent header"""
|
||||
config = GithubCopilotConfig()
|
||||
|
||||
# Mock the authenticator
|
||||
config.authenticator = MagicMock()
|
||||
config.authenticator.get_api_key.return_value = "gh.test-key-123"
|
||||
config.authenticator.get_api_base.return_value = None
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": "You are an assistant."},
|
||||
{"role": "assistant", "content": "I can help you."},
|
||||
]
|
||||
|
||||
headers = config.validate_environment(
|
||||
headers={},
|
||||
model="github_copilot/gpt-4",
|
||||
messages=messages,
|
||||
optional_params={},
|
||||
litellm_params={},
|
||||
api_key=None,
|
||||
api_base=None,
|
||||
)
|
||||
|
||||
assert headers["X-Initiator"] == "agent"
|
||||
|
||||
|
||||
def test_x_initiator_header_agent_request_with_tool():
|
||||
"""Test that messages with tool role result in X-Initiator: agent header"""
|
||||
config = GithubCopilotConfig()
|
||||
|
||||
# Mock the authenticator
|
||||
config.authenticator = MagicMock()
|
||||
config.authenticator.get_api_key.return_value = "gh.test-key-123"
|
||||
config.authenticator.get_api_base.return_value = None
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": "You are an assistant."},
|
||||
{"role": "tool", "content": "Tool response.", "tool_call_id": "123"},
|
||||
]
|
||||
|
||||
headers = config.validate_environment(
|
||||
headers={},
|
||||
model="github_copilot/gpt-4",
|
||||
messages=messages,
|
||||
optional_params={},
|
||||
litellm_params={},
|
||||
api_key=None,
|
||||
api_base=None,
|
||||
)
|
||||
|
||||
assert headers["X-Initiator"] == "agent"
|
||||
|
||||
|
||||
def test_x_initiator_header_mixed_messages_with_agent_roles():
|
||||
"""Test that mixed messages with agent roles (assistant/tool) result in X-Initiator: agent header"""
|
||||
config = GithubCopilotConfig()
|
||||
|
||||
# Mock the authenticator
|
||||
config.authenticator = MagicMock()
|
||||
config.authenticator.get_api_key.return_value = "gh.test-key-123"
|
||||
config.authenticator.get_api_base.return_value = None
|
||||
|
||||
messages = [
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Previous response."},
|
||||
{"role": "user", "content": "Follow up question."},
|
||||
]
|
||||
|
||||
headers = config.validate_environment(
|
||||
headers={},
|
||||
model="github_copilot/gpt-4",
|
||||
messages=messages,
|
||||
optional_params={},
|
||||
litellm_params={},
|
||||
api_key=None,
|
||||
api_base=None,
|
||||
)
|
||||
|
||||
assert headers["X-Initiator"] == "agent"
|
||||
|
||||
|
||||
def test_x_initiator_header_user_only_messages():
|
||||
"""Test that user + system only messages result in X-Initiator: user header"""
|
||||
config = GithubCopilotConfig()
|
||||
|
||||
# Mock the authenticator
|
||||
config.authenticator = MagicMock()
|
||||
config.authenticator.get_api_key.return_value = "gh.test-key-123"
|
||||
config.authenticator.get_api_base.return_value = None
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": "You are an assistant."},
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "user", "content": "Follow up question."},
|
||||
]
|
||||
|
||||
headers = config.validate_environment(
|
||||
headers={},
|
||||
model="github_copilot/gpt-4",
|
||||
messages=messages,
|
||||
optional_params={},
|
||||
litellm_params={},
|
||||
api_key=None,
|
||||
api_base=None,
|
||||
)
|
||||
|
||||
assert headers["X-Initiator"] == "user"
|
||||
|
||||
|
||||
def test_x_initiator_header_empty_messages():
|
||||
"""Test that empty messages result in X-Initiator: user header"""
|
||||
config = GithubCopilotConfig()
|
||||
|
||||
# Mock the authenticator
|
||||
config.authenticator = MagicMock()
|
||||
config.authenticator.get_api_key.return_value = "gh.test-key-123"
|
||||
config.authenticator.get_api_base.return_value = None
|
||||
|
||||
messages = []
|
||||
|
||||
headers = config.validate_environment(
|
||||
headers={},
|
||||
model="github_copilot/gpt-4",
|
||||
messages=messages,
|
||||
optional_params={},
|
||||
litellm_params={},
|
||||
api_key=None,
|
||||
api_base=None,
|
||||
)
|
||||
|
||||
assert headers["X-Initiator"] == "user"
|
||||
|
||||
|
||||
def test_x_initiator_header_system_only_messages():
|
||||
"""Test that system-only messages result in X-Initiator: user header"""
|
||||
config = GithubCopilotConfig()
|
||||
|
||||
# Mock the authenticator
|
||||
config.authenticator = MagicMock()
|
||||
config.authenticator.get_api_key.return_value = "gh.test-key-123"
|
||||
config.authenticator.get_api_base.return_value = None
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": "You are an assistant."},
|
||||
]
|
||||
|
||||
headers = config.validate_environment(
|
||||
headers={},
|
||||
model="github_copilot/gpt-4",
|
||||
messages=messages,
|
||||
optional_params={},
|
||||
litellm_params={},
|
||||
api_key=None,
|
||||
api_base=None,
|
||||
)
|
||||
|
||||
assert headers["X-Initiator"] == "user"
|
Reference in New Issue
Block a user