Added LiteLLM to the stack
@@ -0,0 +1 @@
"""Tests for the LiteLLM Proxy Client CLI package."""
@@ -0,0 +1,437 @@
import json
import os
import sys
import tempfile
import time
from pathlib import Path
from unittest.mock import MagicMock, Mock, mock_open, patch

sys.path.insert(
    0, os.path.abspath("../../..")
)  # Adds the parent directory to the system path


import pytest
from click.testing import CliRunner

from litellm.proxy.client.cli.commands.auth import (
    clear_token,
    get_stored_api_key,
    get_token_file_path,
    load_token,
    login,
    logout,
    save_token,
    whoami,
)


class TestTokenUtilities:
    """Test token file utility functions"""

    def test_get_token_file_path(self):
        """Test getting token file path"""
        with patch('pathlib.Path.home') as mock_home, \
             patch('pathlib.Path.mkdir') as mock_mkdir:
            mock_home.return_value = Path('/home/user')

            result = get_token_file_path()

            assert result == '/home/user/.litellm/token.json'
            mock_mkdir.assert_called_once_with(exist_ok=True)

    def test_get_token_file_path_creates_directory(self):
        """Test that get_token_file_path creates the config directory"""
        with patch('pathlib.Path.home') as mock_home, \
             patch('pathlib.Path.mkdir') as mock_mkdir:
            mock_home.return_value = Path('/home/user')

            get_token_file_path()

            mock_mkdir.assert_called_once_with(exist_ok=True)

    def test_save_token(self):
        """Test saving token data to file"""
        token_data = {
            'key': 'test-key',
            'user_id': 'test-user',
            'timestamp': 1234567890
        }

        with patch('builtins.open', mock_open()) as mock_file, \
             patch('litellm.proxy.client.cli.commands.auth.get_token_file_path') as mock_path, \
             patch('os.chmod') as mock_chmod:

            mock_path.return_value = '/test/path/token.json'

            save_token(token_data)

            mock_file.assert_called_once_with('/test/path/token.json', 'w')
            mock_file().write.assert_called()
            mock_chmod.assert_called_once_with('/test/path/token.json', 0o600)

            # Verify JSON content was written correctly
            written_content = ''.join(call[0][0] for call in mock_file().write.call_args_list)
            parsed_content = json.loads(written_content)
            assert parsed_content == token_data

    def test_load_token_success(self):
        """Test loading token data from file successfully"""
        token_data = {
            'key': 'test-key',
            'user_id': 'test-user',
            'timestamp': 1234567890
        }

        with patch('builtins.open', mock_open(read_data=json.dumps(token_data))), \
             patch('litellm.proxy.client.cli.commands.auth.get_token_file_path') as mock_path, \
             patch('os.path.exists', return_value=True):

            mock_path.return_value = '/test/path/token.json'

            result = load_token()

            assert result == token_data

    def test_load_token_file_not_exists(self):
        """Test loading token when file doesn't exist"""
        with patch('litellm.proxy.client.cli.commands.auth.get_token_file_path') as mock_path, \
             patch('os.path.exists', return_value=False):

            mock_path.return_value = '/test/path/token.json'

            result = load_token()

            assert result is None

    def test_load_token_json_decode_error(self):
        """Test loading token with invalid JSON"""
        with patch('builtins.open', mock_open(read_data='invalid json')), \
             patch('litellm.proxy.client.cli.commands.auth.get_token_file_path') as mock_path, \
             patch('os.path.exists', return_value=True):

            mock_path.return_value = '/test/path/token.json'

            result = load_token()

            assert result is None

    def test_load_token_io_error(self):
        """Test loading token with IO error"""
        with patch('builtins.open', side_effect=IOError("Permission denied")), \
             patch('litellm.proxy.client.cli.commands.auth.get_token_file_path') as mock_path, \
             patch('os.path.exists', return_value=True):

            mock_path.return_value = '/test/path/token.json'

            result = load_token()

            assert result is None

    def test_clear_token_file_exists(self):
        """Test clearing token when file exists"""
        with patch('litellm.proxy.client.cli.commands.auth.get_token_file_path') as mock_path, \
             patch('os.path.exists', return_value=True), \
             patch('os.remove') as mock_remove:

            mock_path.return_value = '/test/path/token.json'

            clear_token()

            mock_remove.assert_called_once_with('/test/path/token.json')

    def test_clear_token_file_not_exists(self):
        """Test clearing token when file doesn't exist"""
        with patch('litellm.proxy.client.cli.commands.auth.get_token_file_path') as mock_path, \
             patch('os.path.exists', return_value=False), \
             patch('os.remove') as mock_remove:

            mock_path.return_value = '/test/path/token.json'

            clear_token()

            mock_remove.assert_not_called()

    def test_get_stored_api_key_success(self):
        """Test getting stored API key successfully"""
        token_data = {
            'key': 'test-api-key-123',
            'user_id': 'test-user'
        }

        with patch('litellm.proxy.client.cli.commands.auth.load_token', return_value=token_data):
            result = get_stored_api_key()
            assert result == 'test-api-key-123'

    def test_get_stored_api_key_no_token(self):
        """Test getting stored API key when no token exists"""
        with patch('litellm.proxy.client.cli.commands.auth.load_token', return_value=None):
            result = get_stored_api_key()
            assert result is None

    def test_get_stored_api_key_no_key_field(self):
        """Test getting stored API key when token has no key field"""
        token_data = {
            'user_id': 'test-user'
        }

        with patch('litellm.proxy.client.cli.commands.auth.load_token', return_value=token_data):
            result = get_stored_api_key()
            assert result is None


class TestLoginCommand:
    """Test login CLI command"""

    def setup_method(self):
        """Setup for each test"""
        self.runner = CliRunner()

    def test_login_success(self):
        """Test successful login flow"""
        mock_context = Mock()
        mock_context.obj = {"base_url": "https://test.example.com"}

        # Mock the requests for successful authentication
        mock_response = Mock()
        mock_response.status_code = 200
        mock_response.json.return_value = {
            "status": "ready",
            "key": "sk-test-api-key-123"
        }

        with patch('webbrowser.open') as mock_browser, \
             patch('requests.get', return_value=mock_response) as mock_get, \
             patch('litellm.proxy.client.cli.commands.auth.save_token') as mock_save, \
             patch('litellm.proxy.client.cli.interface.show_commands') as mock_show_commands, \
             patch('uuid.uuid4', return_value='test-uuid-123'):

            result = self.runner.invoke(login, obj=mock_context.obj)

            assert result.exit_code == 0
            assert "✅ Login successful!" in result.output
            assert "API Key: sk-test-api-key-123" in result.output

            # Verify browser was opened with correct URL
            mock_browser.assert_called_once()
            call_args = mock_browser.call_args[0][0]
            assert "https://test.example.com/sso/key/generate" in call_args
            assert "sk-test-uuid-123" in call_args

            # Verify token was saved
            mock_save.assert_called_once()
            saved_data = mock_save.call_args[0][0]
            assert saved_data['key'] == 'sk-test-api-key-123'
            assert saved_data['user_id'] == 'cli-user'

            # Verify commands were shown
            mock_show_commands.assert_called_once()

    def test_login_timeout(self):
        """Test login timeout scenario"""
        mock_context = Mock()
        mock_context.obj = {"base_url": "https://test.example.com"}

        # Mock response that never returns ready status
        mock_response = Mock()
        mock_response.status_code = 200
        mock_response.json.return_value = {"status": "pending"}

        with patch('webbrowser.open'), \
             patch('requests.get', return_value=mock_response), \
             patch('time.sleep') as mock_sleep, \
             patch('uuid.uuid4', return_value='test-uuid-123'):

            # Mock time.sleep to avoid actual delays in tests
            result = self.runner.invoke(login, obj=mock_context.obj)

            assert result.exit_code == 0
            assert "❌ Authentication timed out" in result.output

    def test_login_http_error(self):
        """Test login with HTTP error"""
        mock_context = Mock()
        mock_context.obj = {"base_url": "https://test.example.com"}

        # Mock response with HTTP error
        mock_response = Mock()
        mock_response.status_code = 500

        with patch('webbrowser.open'), \
             patch('requests.get', return_value=mock_response), \
             patch('time.sleep'), \
             patch('uuid.uuid4', return_value='test-uuid-123'):

            result = self.runner.invoke(login, obj=mock_context.obj)

            assert result.exit_code == 0
            assert "❌ Authentication timed out" in result.output

    def test_login_request_exception(self):
        """Test login with request exception"""
        import requests
        mock_context = Mock()
        mock_context.obj = {"base_url": "https://test.example.com"}

        with patch('webbrowser.open'), \
             patch('requests.get', side_effect=requests.RequestException("Connection failed")), \
             patch('time.sleep'), \
             patch('uuid.uuid4', return_value='test-uuid-123'):

            result = self.runner.invoke(login, obj=mock_context.obj)

            assert result.exit_code == 0
            assert "❌ Authentication timed out" in result.output

    def test_login_keyboard_interrupt(self):
        """Test login cancelled by user"""
        mock_context = Mock()
        mock_context.obj = {"base_url": "https://test.example.com"}

        with patch('webbrowser.open'), \
             patch('requests.get', side_effect=KeyboardInterrupt), \
             patch('uuid.uuid4', return_value='test-uuid-123'):

            result = self.runner.invoke(login, obj=mock_context.obj)

            assert result.exit_code == 0
            assert "❌ Authentication cancelled by user" in result.output

    def test_login_no_api_key_in_response(self):
        """Test login when response doesn't contain API key"""
        mock_context = Mock()
        mock_context.obj = {"base_url": "https://test.example.com"}

        # Mock response without API key
        mock_response = Mock()
        mock_response.status_code = 200
        mock_response.json.return_value = {
            "status": "ready"
            # Missing 'key' field
        }

        with patch('webbrowser.open'), \
             patch('requests.get', return_value=mock_response), \
             patch('time.sleep'), \
             patch('uuid.uuid4', return_value='test-uuid-123'):

            result = self.runner.invoke(login, obj=mock_context.obj)

            assert result.exit_code == 0
            assert "❌ Authentication timed out" in result.output

    def test_login_general_exception(self):
        """Test login with general exception (not a requests exception)"""
        mock_context = Mock()
        mock_context.obj = {"base_url": "https://test.example.com"}

        with patch('webbrowser.open'), \
             patch('requests.get', side_effect=ValueError("Invalid value")), \
             patch('uuid.uuid4', return_value='test-uuid-123'):

            result = self.runner.invoke(login, obj=mock_context.obj)

            assert result.exit_code == 0
            assert "❌ Authentication failed: Invalid value" in result.output


class TestLogoutCommand:
    """Test logout CLI command"""

    def setup_method(self):
        """Setup for each test"""
        self.runner = CliRunner()

    def test_logout_success(self):
        """Test successful logout"""
        with patch('litellm.proxy.client.cli.commands.auth.clear_token') as mock_clear:
            result = self.runner.invoke(logout)

            assert result.exit_code == 0
            assert "✅ Logged out successfully" in result.output
            mock_clear.assert_called_once()


class TestWhoamiCommand:
    """Test whoami CLI command"""

    def setup_method(self):
        """Setup for each test"""
        self.runner = CliRunner()

    def test_whoami_authenticated(self):
        """Test whoami when user is authenticated"""
        token_data = {
            'user_email': 'test@example.com',
            'user_id': 'test-user-123',
            'user_role': 'admin',
            'timestamp': time.time() - 3600  # 1 hour ago
        }

        with patch('litellm.proxy.client.cli.commands.auth.load_token', return_value=token_data):
            result = self.runner.invoke(whoami)

            assert result.exit_code == 0
            assert "✅ Authenticated" in result.output
            assert "test@example.com" in result.output
            assert "test-user-123" in result.output
            assert "admin" in result.output
            assert "Token age: 1.0 hours" in result.output

    def test_whoami_not_authenticated(self):
        """Test whoami when user is not authenticated"""
        with patch('litellm.proxy.client.cli.commands.auth.load_token', return_value=None):
            result = self.runner.invoke(whoami)

            assert result.exit_code == 0
            assert "❌ Not authenticated" in result.output
            assert "Run 'litellm-proxy login'" in result.output

    def test_whoami_old_token(self):
        """Test whoami with old token showing warning"""
        token_data = {
            'user_email': 'test@example.com',
            'user_id': 'test-user-123',
            'user_role': 'admin',
            'timestamp': time.time() - (25 * 3600)  # 25 hours ago
        }

        with patch('litellm.proxy.client.cli.commands.auth.load_token', return_value=token_data):
            result = self.runner.invoke(whoami)

            assert result.exit_code == 0
            assert "✅ Authenticated" in result.output
            assert "⚠️ Warning: Token is more than 24 hours old" in result.output

    def test_whoami_missing_fields(self):
        """Test whoami with token missing some fields"""
        token_data = {
            'timestamp': time.time() - 3600
            # Missing user_email, user_id, user_role
        }

        with patch('litellm.proxy.client.cli.commands.auth.load_token', return_value=token_data):
            result = self.runner.invoke(whoami)

            assert result.exit_code == 0
            assert "✅ Authenticated" in result.output
            assert "Unknown" in result.output  # Should show "Unknown" for missing fields

    def test_whoami_no_timestamp(self):
        """Test whoami with token missing timestamp"""
        token_data = {
            'user_email': 'test@example.com',
            'user_id': 'test-user-123',
            'user_role': 'admin'
            # Missing timestamp
        }

        with patch('litellm.proxy.client.cli.commands.auth.load_token', return_value=token_data), \
             patch('time.time', return_value=1000):

            result = self.runner.invoke(whoami)

            assert result.exit_code == 0
            assert "✅ Authenticated" in result.output
            # Should calculate age based on timestamp=0
            assert "Token age:" in result.output
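
For context, a minimal sketch of the token helpers these tests assume. It is inferred only from the assertions above (path under ~/.litellm, mkdir(exist_ok=True), chmod 0o600, None for a missing or unreadable token file) and is not the shipped implementation:

import json
import os
from pathlib import Path


def get_token_file_path() -> str:
    # ~/.litellm/token.json; the config directory is created on first use
    config_dir = Path.home() / ".litellm"
    config_dir.mkdir(exist_ok=True)
    return str(config_dir / "token.json")


def save_token(token_data: dict) -> None:
    path = get_token_file_path()
    with open(path, "w") as f:
        json.dump(token_data, f)
    os.chmod(path, 0o600)  # keep the credential file readable by the owner only


def load_token():
    path = get_token_file_path()
    if not os.path.exists(path):
        return None
    try:
        with open(path) as f:
            return json.load(f)
    except (json.JSONDecodeError, IOError):
        return None  # invalid JSON and IO errors both map to "no token"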
@@ -0,0 +1,248 @@
import json
import os
import sys
from unittest.mock import MagicMock, patch

import pytest
import requests
from click.testing import CliRunner

sys.path.insert(
    0, os.path.abspath("../../..")
)  # Adds the parent directory to the system path


from litellm.proxy.client.cli.main import cli


@pytest.fixture
def mock_chat_client():
    with patch("litellm.proxy.client.cli.commands.chat.ChatClient") as mock:
        yield mock


@pytest.fixture
def cli_runner():
    return CliRunner()


def test_chat_completions_success(cli_runner, mock_chat_client):
    # Mock response data
    mock_response = {
        "id": "chatcmpl-123",
        "object": "chat.completion",
        "created": 1677858242,
        "model": "gpt-4",
        "choices": [
            {
                "message": {
                    "role": "assistant",
                    "content": "Hello! How can I help you today?",
                },
                "finish_reason": "stop",
                "index": 0,
            }
        ],
    }
    mock_instance = mock_chat_client.return_value
    mock_instance.completions.return_value = mock_response

    # Run command
    result = cli_runner.invoke(
        cli,
        [
            "chat",
            "completions",
            "gpt-4",
            "-m",
            "user:Hello!",
            "--temperature",
            "0.7",
            "--max-tokens",
            "100",
        ],
    )

    # Verify
    assert result.exit_code == 0
    output_data = json.loads(result.output)
    assert output_data == mock_response
    mock_instance.completions.assert_called_once_with(
        model="gpt-4",
        messages=[{"role": "user", "content": "Hello!"}],
        temperature=0.7,
        max_tokens=100,
        top_p=None,
        n=None,
        presence_penalty=None,
        frequency_penalty=None,
        user=None,
    )


def test_chat_completions_multiple_messages(cli_runner, mock_chat_client):
    # Mock response data
    mock_response = {
        "id": "chatcmpl-123",
        "object": "chat.completion",
        "created": 1677858242,
        "model": "gpt-4",
        "choices": [
            {
                "message": {
                    "role": "assistant",
                    "content": "Paris has a population of about 2.2 million.",
                },
                "finish_reason": "stop",
                "index": 0,
            }
        ],
    }
    mock_instance = mock_chat_client.return_value
    mock_instance.completions.return_value = mock_response

    # Run command
    result = cli_runner.invoke(
        cli,
        [
            "chat",
            "completions",
            "gpt-4",
            "-m",
            "system:You are a helpful assistant",
            "-m",
            "user:What's the population of Paris?",
        ],
    )

    # Verify
    assert result.exit_code == 0
    output_data = json.loads(result.output)
    assert output_data == mock_response
    mock_instance.completions.assert_called_once_with(
        model="gpt-4",
        messages=[
            {"role": "system", "content": "You are a helpful assistant"},
            {"role": "user", "content": "What's the population of Paris?"},
        ],
        temperature=None,
        max_tokens=None,
        top_p=None,
        n=None,
        presence_penalty=None,
        frequency_penalty=None,
        user=None,
    )


def test_chat_completions_no_messages(cli_runner, mock_chat_client):
    # Run command without any messages
    result = cli_runner.invoke(cli, ["chat", "completions", "gpt-4"])

    # Verify
    assert result.exit_code == 2
    assert "At least one message is required" in result.output
    mock_instance = mock_chat_client.return_value
    mock_instance.completions.assert_not_called()


def test_chat_completions_invalid_message_format(cli_runner, mock_chat_client):
    # Run command with invalid message format
    result = cli_runner.invoke(
        cli, ["chat", "completions", "gpt-4", "-m", "invalid-format"]
    )

    # Verify
    assert result.exit_code == 2
    assert "Invalid message format" in result.output
    mock_instance = mock_chat_client.return_value
    mock_instance.completions.assert_not_called()


def test_chat_completions_http_error(cli_runner, mock_chat_client):
    # Mock HTTP error
    mock_instance = mock_chat_client.return_value
    mock_error_response = MagicMock()
    mock_error_response.status_code = 400
    mock_error_response.json.return_value = {
        "error": "Invalid request",
        "message": "Invalid model specified",
    }
    mock_instance.completions.side_effect = requests.exceptions.HTTPError(
        response=mock_error_response
    )

    # Run command
    result = cli_runner.invoke(
        cli, ["chat", "completions", "invalid-model", "-m", "user:Hello"]
    )

    # Verify
    assert result.exit_code == 1
    assert "Error: HTTP 400" in result.output
    assert "Invalid request" in result.output
    assert "Invalid model specified" in result.output


def test_chat_completions_all_parameters(cli_runner, mock_chat_client):
    # Mock response data
    mock_response = {
        "id": "chatcmpl-123",
        "object": "chat.completion",
        "created": 1677858242,
        "model": "gpt-4",
        "choices": [
            {
                "message": {
                    "role": "assistant",
                    "content": "Response with all parameters set",
                },
                "finish_reason": "stop",
                "index": 0,
            }
        ],
    }
    mock_instance = mock_chat_client.return_value
    mock_instance.completions.return_value = mock_response

    # Run command with all available parameters
    result = cli_runner.invoke(
        cli,
        [
            "chat",
            "completions",
            "gpt-4",
            "-m",
            "user:Test message",
            "--temperature",
            "0.7",
            "--top-p",
            "0.9",
            "--n",
            "1",
            "--max-tokens",
            "100",
            "--presence-penalty",
            "0.5",
            "--frequency-penalty",
            "0.5",
            "--user",
            "test-user",
        ],
    )

    # Verify
    assert result.exit_code == 0
    output_data = json.loads(result.output)
    assert output_data == mock_response
    mock_instance.completions.assert_called_once_with(
        model="gpt-4",
        messages=[{"role": "user", "content": "Test message"}],
        temperature=0.7,
        top_p=0.9,
        n=1,
        max_tokens=100,
        presence_penalty=0.5,
        frequency_penalty=0.5,
        user="test-user",
    )
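
For reference, a hedged sketch of the `-m role:content` parsing these tests exercise. It is illustrative only, inferred from the asserted exit codes and messages (a click.UsageError makes Click exit with code 2); the function name is hypothetical:

import click


def parse_messages(message_options):
    # Each -m value is split on the first ":" into a role and its content
    messages = []
    for raw in message_options:
        if ":" not in raw:
            raise click.UsageError(f"Invalid message format: {raw}")
        role, content = raw.split(":", 1)
        messages.append({"role": role, "content": content})
    if not messages:
        raise click.UsageError("At least one message is required")
    return messages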
@@ -0,0 +1,222 @@
import json
import os
import sys
from unittest.mock import MagicMock

import pytest
import requests
from click.testing import CliRunner

sys.path.insert(
    0, os.path.abspath("../../../..")
)  # Adds the parent directory to the system path


from litellm.proxy.client.cli.main import cli


@pytest.fixture
def mock_credentials_client(monkeypatch):
    """Patch the CredentialsManagementClient used by the CLI commands."""
    mock_client = MagicMock()
    monkeypatch.setattr(
        "litellm.proxy.client.credentials.CredentialsManagementClient",
        mock_client,
    )
    monkeypatch.setattr(
        "litellm.proxy.client.cli.commands.credentials.CredentialsManagementClient",
        mock_client,
    )
    return mock_client


@pytest.fixture
def cli_runner():
    return CliRunner()


def test_alist_credentials_table_format(cli_runner, mock_credentials_client):
    # Mock response data
    mock_response = {
        "credentials": [
            {
                "credential_name": "test-cred-1",
                "credential_info": {"custom_llm_provider": "azure"},
            },
            {
                "credential_name": "test-cred-2",
                "credential_info": {"custom_llm_provider": "anthropic"},
            },
        ]
    }
    mock_instance = mock_credentials_client.return_value
    mock_instance.list.return_value = mock_response

    # Run command
    result = cli_runner.invoke(cli, ["credentials", "list"])

    # Verify
    assert result.exit_code == 0
    assert "test-cred-1" in result.output
    assert "azure" in result.output
    assert "test-cred-2" in result.output
    assert "anthropic" in result.output


def test_alist_credentials_json_format(cli_runner, mock_credentials_client):
    # Mock response data
    mock_response = {
        "credentials": [
            {
                "credential_name": "test-cred",
                "credential_info": {"custom_llm_provider": "azure"},
            }
        ]
    }
    mock_instance = mock_credentials_client.return_value
    mock_instance.list.return_value = mock_response

    # Run command
    result = cli_runner.invoke(cli, ["credentials", "list", "--format", "json"])

    # Verify
    assert result.exit_code == 0
    output_data = json.loads(result.output)
    assert output_data == mock_response


def test_acreate_credential_success(cli_runner, mock_credentials_client):
    # Mock response data
    mock_response = {"status": "success", "credential_name": "test-cred"}
    mock_instance = mock_credentials_client.return_value
    mock_instance.create.return_value = mock_response

    # Run command
    result = cli_runner.invoke(
        cli,
        [
            "credentials",
            "create",
            "test-cred",
            "--info",
            '{"custom_llm_provider": "azure"}',
            "--values",
            '{"api_key": "test-key"}',
        ],
    )

    # Verify
    assert result.exit_code == 0
    output_data = json.loads(result.output)
    assert output_data == mock_response
    mock_instance.create.assert_called_once_with(
        "test-cred",
        {"custom_llm_provider": "azure"},
        {"api_key": "test-key"},
    )


def test_acreate_credential_invalid_json(cli_runner, mock_credentials_client):
    # Run command with invalid JSON
    result = cli_runner.invoke(
        cli,
        [
            "credentials",
            "create",
            "test-cred",
            "--info",
            "invalid-json",
            "--values",
            '{"api_key": "test-key"}',
        ],
    )

    # Verify
    assert result.exit_code == 2
    assert "Invalid JSON" in result.output
    mock_instance = mock_credentials_client.return_value
    mock_instance.create.assert_not_called()


def test_acreate_credential_http_error(cli_runner, mock_credentials_client):
    # Mock HTTP error
    mock_instance = mock_credentials_client.return_value
    mock_error_response = MagicMock()
    mock_error_response.status_code = 400
    mock_error_response.json.return_value = {"error": "Invalid request"}
    mock_instance.create.side_effect = requests.exceptions.HTTPError(
        response=mock_error_response
    )

    # Run command
    result = cli_runner.invoke(
        cli,
        [
            "credentials",
            "create",
            "test-cred",
            "--info",
            '{"custom_llm_provider": "azure"}',
            "--values",
            '{"api_key": "test-key"}',
        ],
    )

    # Verify
    assert result.exit_code == 1
    assert "Error: HTTP 400" in result.output
    assert "Invalid request" in result.output


def test_adelete_credential_success(cli_runner, mock_credentials_client):
    # Mock response data
    mock_response = {"status": "success", "message": "Credential deleted"}
    mock_instance = mock_credentials_client.return_value
    mock_instance.delete.return_value = mock_response

    # Run command
    result = cli_runner.invoke(cli, ["credentials", "delete", "test-cred"])

    # Verify
    assert result.exit_code == 0
    output_data = json.loads(result.output)
    assert output_data == mock_response
    mock_instance.delete.assert_called_once_with("test-cred")


def test_adelete_credential_http_error(cli_runner, mock_credentials_client):
    # Mock HTTP error
    mock_instance = mock_credentials_client.return_value
    mock_error_response = MagicMock()
    mock_error_response.status_code = 404
    mock_error_response.json.return_value = {"error": "Credential not found"}
    mock_instance.delete.side_effect = requests.exceptions.HTTPError(
        response=mock_error_response
    )

    # Run command
    result = cli_runner.invoke(cli, ["credentials", "delete", "test-cred"])

    # Verify
    assert result.exit_code == 1
    assert "Error: HTTP 404" in result.output
    assert "Credential not found" in result.output


def test_aget_credential_success(cli_runner, mock_credentials_client):
    # Mock response data
    mock_response = {
        "credential_name": "test-cred",
        "credential_info": {"custom_llm_provider": "azure"},
    }
    mock_instance = mock_credentials_client.return_value
    mock_instance.get.return_value = mock_response

    # Run command
    result = cli_runner.invoke(cli, ["credentials", "get", "test-cred"])

    # Verify
    assert result.exit_code == 0
    output_data = json.loads(result.output)
    assert output_data == mock_response
    mock_instance.get.assert_called_once_with("test-cred")
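
Several tests above (and in the chat tests) assert the same "Error: HTTP <status>" rendering followed by the JSON error body. A hedged sketch of that shared pattern, as an assumption about how the commands might surface requests.exceptions.HTTPError rather than the actual implementation (click.Abort makes the command exit non-zero):

import json

import click
import requests


def echo_http_error(e: requests.exceptions.HTTPError) -> None:
    # Print the status line and the JSON error body, then abort the command
    click.echo(f"Error: HTTP {e.response.status_code}", err=True)
    click.echo(json.dumps(e.response.json(), indent=2), err=True)
    raise click.Abort()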
@@ -0,0 +1,46 @@
# stdlib imports
import os
import sys
from unittest.mock import patch

import pytest
from click.testing import CliRunner

sys.path.insert(
    0, os.path.abspath("../../..")
)  # Adds the parent directory to the system path


from litellm._version import version as litellm_version
from litellm.proxy.client.cli import cli


@pytest.fixture
def cli_runner():
    return CliRunner()


def test_cli_version_flag(cli_runner):
    """Test that --version prints the correct version, server URL, and server version, and exits successfully"""
    with patch(
        "litellm.proxy.client.health.HealthManagementClient.get_server_version",
        return_value="1.2.3",
    ), patch.dict(os.environ, {"LITELLM_PROXY_URL": "http://localhost:4000"}):
        result = cli_runner.invoke(cli, ["--version"])
        assert result.exit_code == 0
        assert f"LiteLLM Proxy CLI Version: {litellm_version}" in result.output
        assert "LiteLLM Proxy Server URL: http://localhost:4000" in result.output
        assert "LiteLLM Proxy Server Version: 1.2.3" in result.output


def test_cli_version_command(cli_runner):
    """Test that 'version' command prints the correct version, server URL, and server version, and exits successfully"""
    with patch(
        "litellm.proxy.client.health.HealthManagementClient.get_server_version",
        return_value="1.2.3",
    ), patch.dict(os.environ, {"LITELLM_PROXY_URL": "http://localhost:4000"}):
        result = cli_runner.invoke(cli, ["version"])
        assert result.exit_code == 0
        assert f"LiteLLM Proxy CLI Version: {litellm_version}" in result.output
        assert "LiteLLM Proxy Server URL: http://localhost:4000" in result.output
        assert "LiteLLM Proxy Server Version: 1.2.3" in result.output
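
A minimal sketch of the three-line report both tests assert. These are assumptions inferred from the assertions alone: the CLI version comes from litellm._version, the URL from the LITELLM_PROXY_URL env var, and the server version from the patched health-client call; the helper name is hypothetical:

import os

import click

from litellm._version import version as litellm_version


def print_version_report(get_server_version) -> None:
    # get_server_version stands in for HealthManagementClient.get_server_version
    click.echo(f"LiteLLM Proxy CLI Version: {litellm_version}")
    click.echo(f"LiteLLM Proxy Server URL: {os.environ.get('LITELLM_PROXY_URL')}")
    click.echo(f"LiteLLM Proxy Server Version: {get_server_version()}")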
@@ -0,0 +1,481 @@
import json
import os
import sys
from unittest.mock import patch

import requests

sys.path.insert(
    0, os.path.abspath("../../..")
)  # Adds the parent directory to the system path


import pytest
from click.testing import CliRunner

from litellm.proxy.client.cli import cli


@pytest.fixture
def cli_runner():
    return CliRunner()


@pytest.fixture(autouse=True)
def mock_env():
    with patch.dict(
        os.environ,
        {
            "LITELLM_PROXY_URL": "http://localhost:4000",
            "LITELLM_PROXY_API_KEY": "sk-test",
        },
    ):
        yield


@pytest.fixture
def mock_keys_client():
    with patch("litellm.proxy.client.cli.commands.keys.KeysManagementClient") as MockClient:
        yield MockClient


def test_async_keys_list_json_format(mock_keys_client, cli_runner):
    mock_keys_client.return_value.list.return_value = {
        "keys": [
            {
                "token": "abc123",
                "key_alias": "alias1",
                "user_id": "u1",
                "team_id": "t1",
                "spend": 10.0,
            }
        ]
    }
    result = cli_runner.invoke(cli, ["keys", "list", "--format", "json"])
    assert result.exit_code == 0
    output_data = json.loads(result.output)
    assert output_data == mock_keys_client.return_value.list.return_value
    mock_keys_client.assert_called_once_with("http://localhost:4000", "sk-test")
    mock_keys_client.return_value.list.assert_called_once()


def test_async_keys_list_table_format(mock_keys_client, cli_runner):
    mock_keys_client.return_value.list.return_value = {
        "keys": [
            {
                "token": "abc123",
                "key_alias": "alias1",
                "user_id": "u1",
                "team_id": "t1",
                "spend": 10.0,
            }
        ]
    }
    result = cli_runner.invoke(cli, ["keys", "list"])
    assert result.exit_code == 0
    assert "abc123" in result.output
    assert "alias1" in result.output
    assert "u1" in result.output
    assert "t1" in result.output
    assert "10.0" in result.output
    mock_keys_client.assert_called_once_with("http://localhost:4000", "sk-test")
    mock_keys_client.return_value.list.assert_called_once()


def test_async_keys_generate_success(mock_keys_client, cli_runner):
    mock_keys_client.return_value.generate.return_value = {
        "key": "new-key",
        "spend": 100.0,
    }
    result = cli_runner.invoke(cli, ["keys", "generate", "--models", "gpt-4", "--spend", "100"])
    assert result.exit_code == 0
    assert "new-key" in result.output
    mock_keys_client.return_value.generate.assert_called_once()


def test_async_keys_delete_success(mock_keys_client, cli_runner):
    mock_keys_client.return_value.delete.return_value = {
        "status": "success",
        "deleted_keys": ["abc123"],
    }
    result = cli_runner.invoke(cli, ["keys", "delete", "--keys", "abc123"])
    assert result.exit_code == 0
    assert "success" in result.output
    assert "abc123" in result.output
    mock_keys_client.return_value.delete.assert_called_once()


def test_async_keys_list_error_handling(mock_keys_client, cli_runner):
    mock_keys_client.return_value.list.side_effect = Exception("API Error")
    result = cli_runner.invoke(cli, ["keys", "list"])
    assert result.exit_code != 0
    assert "API Error" in str(result.exception)


def test_async_keys_generate_error_handling(mock_keys_client, cli_runner):
    mock_keys_client.return_value.generate.side_effect = Exception("API Error")
    result = cli_runner.invoke(cli, ["keys", "generate", "--models", "gpt-4"])
    assert result.exit_code != 0
    assert "API Error" in str(result.exception)


def test_async_keys_delete_error_handling(mock_keys_client, cli_runner):
    import requests

    # Mock a connection error that would normally happen in CI
    mock_keys_client.return_value.delete.side_effect = requests.exceptions.ConnectionError(
        "Connection error"
    )
    result = cli_runner.invoke(cli, ["keys", "delete", "--keys", "abc123"])
    assert result.exit_code != 0
    # Check that the exception is properly propagated
    assert result.exception is not None
    # The ConnectionError should propagate since it's not caught by the HTTPError handler
    # Check for connection-related keywords that appear in both mocked and real errors
    error_str = str(result.exception).lower()
    assert any(keyword in error_str for keyword in ["connection", "connect", "refused", "error"])


def test_async_keys_delete_http_error_handling(mock_keys_client, cli_runner):
    from unittest.mock import Mock

    import requests

    # Create a mock response object for the HTTPError
    mock_response = Mock()
    mock_response.status_code = 400
    mock_response.json.return_value = {"error": "Bad request"}

    # Mock an HTTPError, which should be caught by the delete command
    http_error = requests.exceptions.HTTPError("HTTP Error")
    http_error.response = mock_response
    mock_keys_client.return_value.delete.side_effect = http_error

    result = cli_runner.invoke(cli, ["keys", "delete", "--keys", "abc123"])
    assert result.exit_code != 0
    # HTTPError should be caught and converted to click.Abort
    assert isinstance(result.exception, SystemExit)  # click.Abort raises SystemExit


# Tests for keys import command
def test_keys_import_dry_run_success(mock_keys_client, cli_runner):
    """Test successful dry-run import showing table of keys that would be imported"""
    # Mock source client response (paginated)
    mock_source_instance = mock_keys_client.return_value
    mock_source_instance.list.side_effect = [
        {
            "keys": [
                {
                    "key_alias": "test-key-1",
                    "user_id": "user1@example.com",
                    "created_at": "2024-01-15T10:30:00Z",
                    "models": ["gpt-4"],
                    "spend": 10.0,
                },
                {
                    "key_alias": "test-key-2",
                    "user_id": "user2@example.com",
                    "created_at": "2024-01-16T11:45:00Z",
                    "models": [],
                    "spend": 5.0,
                }
            ]
        },
        {"keys": []}  # Empty second page
    ]

    result = cli_runner.invoke(cli, [
        "keys", "import",
        "--source-base-url", "https://source.example.com",
        "--source-api-key", "sk-source-123",
        "--dry-run"
    ])

    assert result.exit_code == 0
    assert "Found 2 keys in source instance" in result.output
    assert "DRY RUN MODE" in result.output
    assert "test-key-1" in result.output
    assert "user1@example.com" in result.output
    assert "test-key-2" in result.output
    assert "user2@example.com" in result.output

    # Verify source client was called (pagination stops early when fewer keys than page_size)
    assert mock_source_instance.list.call_count >= 1
    mock_source_instance.list.assert_any_call(return_full_object=True, page=1, size=100)


def test_keys_import_actual_import_success(mock_keys_client, cli_runner):
    """Test successful actual import of keys"""
    # Note: both names below refer to the same mock instance, since the patched
    # class exposes a single return_value for source and destination clients
    with patch("litellm.proxy.client.cli.commands.keys.KeysManagementClient") as MockClient:
        mock_source_instance = MockClient.return_value
        mock_dest_instance = MockClient.return_value

        # Configure source client
        mock_source_instance.list.side_effect = [
            {
                "keys": [
                    {
                        "key_alias": "import-key-1",
                        "user_id": "user1@example.com",
                        "models": ["gpt-4"],
                        "spend": 100.0,
                        "team_id": "team-1"
                    }
                ]
            },
            {"keys": []}  # Empty second page
        ]

        # Configure destination client
        mock_dest_instance.generate.return_value = {
            "key": "sk-new-generated-key",
            "status": "success"
        }

        result = cli_runner.invoke(cli, [
            "keys", "import",
            "--source-base-url", "https://source.example.com",
            "--source-api-key", "sk-source-123"
        ])

        assert result.exit_code == 0
        assert "Found 1 keys in source instance" in result.output
        assert "✓ Imported key: import-key-1" in result.output
        assert "Successfully imported: 1" in result.output
        assert "Failed to import: 0" in result.output

        # Verify generate was called with correct parameters
        mock_dest_instance.generate.assert_called_once_with(
            models=["gpt-4"],
            spend=100.0,
            key_alias="import-key-1",
            team_id="team-1",
            user_id="user1@example.com"
        )


def test_keys_import_pagination_handling(mock_keys_client, cli_runner):
    """Test that import correctly handles pagination to get all keys"""
    mock_source_instance = mock_keys_client.return_value
    mock_source_instance.list.side_effect = [
        {"keys": [{"key_alias": f"key-{i}", "user_id": f"user{i}@example.com"} for i in range(100)]},  # Page 1: 100 keys
        {"keys": [{"key_alias": f"key-{i}", "user_id": f"user{i}@example.com"} for i in range(100, 150)]},  # Page 2: 50 keys
        {"keys": []}  # Page 3: Empty
    ]

    result = cli_runner.invoke(cli, [
        "keys", "import",
        "--source-base-url", "https://source.example.com",
        "--dry-run"
    ])

    assert result.exit_code == 0
    assert "Fetched page 1: 100 keys" in result.output
    assert "Fetched page 2: 50 keys" in result.output
    assert "Found 150 keys in source instance" in result.output

    # Verify pagination calls (stops early when fewer keys than page_size)
    assert mock_source_instance.list.call_count >= 2
    mock_source_instance.list.assert_any_call(return_full_object=True, page=1, size=100)
    mock_source_instance.list.assert_any_call(return_full_object=True, page=2, size=100)


def test_keys_import_created_since_filter(mock_keys_client, cli_runner):
    """Test that --created-since filter works correctly"""
    mock_source_instance = mock_keys_client.return_value
    mock_source_instance.list.side_effect = [
        {
            "keys": [
                {
                    "key_alias": "old-key",
                    "user_id": "user1@example.com",
                    "created_at": "2024-01-01T10:00:00Z"  # Before filter
                },
                {
                    "key_alias": "new-key",
                    "user_id": "user2@example.com",
                    "created_at": "2024-07-08T10:00:00Z"  # After filter
                }
            ]
        },
        {"keys": []}
    ]

    result = cli_runner.invoke(cli, [
        "keys", "import",
        "--source-base-url", "https://source.example.com",
        "--created-since", "2024-07-07_18:19",
        "--dry-run"
    ])

    assert result.exit_code == 0
    assert "Filtered 2 keys to 1 keys created since 2024-07-07_18:19" in result.output
    assert "Found 1 keys in source instance" in result.output
    assert "new-key" in result.output
    assert "old-key" not in result.output


def test_keys_import_created_since_date_only_format(mock_keys_client, cli_runner):
    """Test --created-since with date-only format (YYYY-MM-DD)"""
    mock_source_instance = mock_keys_client.return_value
    mock_source_instance.list.side_effect = [
        {
            "keys": [
                {
                    "key_alias": "test-key",
                    "user_id": "user@example.com",
                    "created_at": "2024-07-08T10:00:00Z"
                }
            ]
        },
        {"keys": []}
    ]

    result = cli_runner.invoke(cli, [
        "keys", "import",
        "--source-base-url", "https://source.example.com",
        "--created-since", "2024-07-07",  # Date only format
        "--dry-run"
    ])

    assert result.exit_code == 0
    assert "Filtered 1 keys to 1 keys created since 2024-07-07" in result.output


def test_keys_import_no_keys_found(mock_keys_client, cli_runner):
    """Test handling when no keys are found in source instance"""
    mock_source_instance = mock_keys_client.return_value
    mock_source_instance.list.return_value = {"keys": []}

    result = cli_runner.invoke(cli, [
        "keys", "import",
        "--source-base-url", "https://source.example.com",
        "--dry-run"
    ])

    assert result.exit_code == 0
    assert "No keys found in source instance" in result.output


def test_keys_import_invalid_date_format(cli_runner):
    """Test error handling for invalid --created-since date format"""
    result = cli_runner.invoke(cli, [
        "keys", "import",
        "--source-base-url", "https://source.example.com",
        "--created-since", "invalid-date",
        "--dry-run"
    ])

    assert result.exit_code != 0
    assert "Invalid date format" in result.output
    assert "Use YYYY-MM-DD_HH:MM or YYYY-MM-DD" in result.output


def test_keys_import_source_api_error(mock_keys_client, cli_runner):
    """Test error handling when source API returns an error"""
    mock_source_instance = mock_keys_client.return_value
    mock_source_instance.list.side_effect = Exception("Source API Error")

    result = cli_runner.invoke(cli, [
        "keys", "import",
        "--source-base-url", "https://source.example.com",
        "--dry-run"
    ])

    assert result.exit_code != 0
    assert "Source API Error" in result.output


def test_keys_import_partial_failure(mock_keys_client, cli_runner):
    """Test handling when some keys fail to import"""
    with patch("litellm.proxy.client.cli.commands.keys.KeysManagementClient") as MockClient:
        mock_source_instance = MockClient.return_value
        mock_dest_instance = MockClient.return_value

        # Source returns 2 keys
        mock_source_instance.list.side_effect = [
            {
                "keys": [
                    {"key_alias": "success-key", "user_id": "user1@example.com"},
                    {"key_alias": "fail-key", "user_id": "user2@example.com"}
                ]
            },
            {"keys": []}
        ]

        # Destination: first generate succeeds, second fails
        mock_dest_instance.generate.side_effect = [
            {"key": "sk-new-key", "status": "success"},
            Exception("Import failed for this key")
        ]

        result = cli_runner.invoke(cli, [
            "keys", "import",
            "--source-base-url", "https://source.example.com"
        ])

        assert result.exit_code == 0  # Command completes even with partial failures
        assert "✓ Imported key: success-key" in result.output
        assert "✗ Failed to import key fail-key" in result.output
        assert "Successfully imported: 1" in result.output
        assert "Failed to import: 1" in result.output
        assert "Total keys processed: 2" in result.output


def test_keys_import_missing_required_source_url(cli_runner):
    """Test error when required --source-base-url is missing"""
    result = cli_runner.invoke(cli, [
        "keys", "import",
        "--dry-run"
    ])

    assert result.exit_code != 0
    assert "Missing option" in result.output or "required" in result.output.lower()


def test_keys_import_with_all_key_properties(mock_keys_client, cli_runner):
    """Test import preserves all key properties (models, aliases, config, etc.)"""
    with patch("litellm.proxy.client.cli.commands.keys.KeysManagementClient") as MockClient:
        mock_source_instance = MockClient.return_value
        mock_dest_instance = MockClient.return_value

        mock_source_instance.list.side_effect = [
            {
                "keys": [
                    {
                        "key_alias": "full-key",
                        "user_id": "user@example.com",
                        "team_id": "team-123",
                        "budget_id": "budget-456",
                        "models": ["gpt-4", "gpt-3.5-turbo"],
                        "aliases": {"custom-model": "gpt-4"},
                        "spend": 50.0,
                        "config": {"max_tokens": 1000}
                    }
                ]
            },
            {"keys": []}
        ]

        mock_dest_instance.generate.return_value = {"key": "sk-imported", "status": "success"}

        result = cli_runner.invoke(cli, [
            "keys", "import",
            "--source-base-url", "https://source.example.com"
        ])

        assert result.exit_code == 0

        # Verify all properties were passed to generate
        mock_dest_instance.generate.assert_called_once_with(
            models=["gpt-4", "gpt-3.5-turbo"],
            aliases={"custom-model": "gpt-4"},
            spend=50.0,
            key_alias="full-key",
            team_id="team-123",
            user_id="user@example.com",
            budget_id="budget-456",
            config={"max_tokens": 1000}
        )
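
The import tests above pin down a pagination contract: pages of size 100 are fetched via list(return_full_object=True, page=N, size=100), and fetching stops once a page comes back shorter than the page size. A hedged sketch of that loop, illustrative only and not the shipped code:

def fetch_all_keys(client, page_size=100):
    # Accumulate keys page by page until a short (or empty) page signals the end
    all_keys = []
    page = 1
    while True:
        response = client.list(return_full_object=True, page=page, size=page_size)
        keys = response.get("keys", [])
        all_keys.extend(keys)
        print(f"Fetched page {page}: {len(keys)} keys")
        if len(keys) < page_size:
            break
        page += 1
    return all_keys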
@@ -0,0 +1,433 @@
|
||||
# stdlib imports
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
# third party imports
|
||||
from click.testing import CliRunner
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../../..")
|
||||
) # Adds the parent directory to the system path
|
||||
|
||||
|
||||
# local imports
|
||||
from litellm.proxy.client.cli import cli
|
||||
from litellm.proxy.client.cli.commands.models import (
|
||||
format_cost_per_1k_tokens,
|
||||
format_iso_datetime_str,
|
||||
format_timestamp,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_client():
|
||||
"""Fixture to create a mock client with common setup"""
|
||||
with patch("litellm.proxy.client.cli.commands.models.Client") as MockClient:
|
||||
yield MockClient
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def cli_runner():
|
||||
"""Fixture for Click CLI runner"""
|
||||
return CliRunner()
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_env():
|
||||
"""Fixture to set up environment variables for all tests"""
|
||||
with patch.dict(
|
||||
os.environ,
|
||||
{
|
||||
"LITELLM_PROXY_URL": "http://localhost:4000",
|
||||
"LITELLM_PROXY_API_KEY": "sk-test",
|
||||
},
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_models_list(mock_client):
|
||||
"""Fixture to set up common mocking pattern for models list tests"""
|
||||
mock_client.return_value.models.list.return_value = [
|
||||
{
|
||||
"id": "model-123",
|
||||
"object": "model",
|
||||
"created": 1699848889,
|
||||
"owned_by": "organization-123",
|
||||
},
|
||||
{
|
||||
"id": "model-456",
|
||||
"object": "model",
|
||||
"created": 1699848890,
|
||||
"owned_by": "organization-456",
|
||||
},
|
||||
]
|
||||
|
||||
mock_client.assert_not_called() # Ensure clean slate
|
||||
return mock_client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_models_info(mock_client):
|
||||
"""Fixture to set up models info mock"""
|
||||
mock_client.return_value.models.info.return_value = [
|
||||
{
|
||||
"model_name": "gpt-4",
|
||||
"litellm_params": {"model": "gpt-4", "litellm_credential_name": "openai-1"},
|
||||
"model_info": {
|
||||
"id": "model-123",
|
||||
"created_at": "2025-04-29T21:31:43.843000+00:00",
|
||||
"updated_at": "2025-04-29T21:31:43.843000+00:00",
|
||||
"input_cost_per_token": 0.00001,
|
||||
"output_cost_per_token": 0.00002,
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
mock_client.assert_not_called()
|
||||
return mock_client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def force_utc_tz():
|
||||
"""Fixture to force UTC timezone for tests that depend on system TZ."""
|
||||
old_tz = os.environ.get("TZ")
|
||||
os.environ["TZ"] = "UTC"
|
||||
if hasattr(time, "tzset"):
|
||||
time.tzset()
|
||||
yield
|
||||
# Restore previous TZ
|
||||
if old_tz is not None:
|
||||
os.environ["TZ"] = old_tz
|
||||
else:
|
||||
if "TZ" in os.environ:
|
||||
del os.environ["TZ"]
|
||||
if hasattr(time, "tzset"):
|
||||
time.tzset()
|
||||
|
||||
|
||||
def test_models_list_json_format(mock_models_list, cli_runner):
|
||||
"""Test the models list command with JSON output format"""
|
||||
# Run the command
|
||||
result = cli_runner.invoke(cli, ["models", "list", "--format", "json"])
|
||||
|
||||
# Check that the command succeeded
|
||||
assert result.exit_code == 0
|
||||
|
||||
# Parse the output and verify it matches our mock data
|
||||
output_data = json.loads(result.output)
|
||||
assert output_data == mock_models_list.return_value.models.list.return_value
|
||||
|
||||
# Verify the client was called correctly
|
||||
mock_models_list.assert_called_once_with(
|
||||
base_url="http://localhost:4000", api_key="sk-test"
|
||||
)
|
||||
mock_models_list.return_value.models.list.assert_called_once()
|
||||
|
||||
|
||||
def test_models_list_table_format(mock_models_list, cli_runner):
|
||||
"""Test the models list command with table output format"""
|
||||
# Run the command
|
||||
result = cli_runner.invoke(cli, ["models", "list"])
|
||||
|
||||
# Check that the command succeeded
|
||||
assert result.exit_code == 0
|
||||
|
||||
# Verify the output contains expected table elements
|
||||
assert "ID" in result.output
|
||||
assert "Object" in result.output
|
||||
assert "Created" in result.output
|
||||
assert "Owned By" in result.output
|
||||
assert "model-123" in result.output
|
||||
assert "organization-123" in result.output
|
||||
assert format_timestamp(1699848889) in result.output
|
||||
|
||||
# Verify the client was called correctly
|
||||
mock_models_list.assert_called_once_with(
|
||||
base_url="http://localhost:4000", api_key="sk-test"
|
||||
)
|
||||
mock_models_list.return_value.models.list.assert_called_once()
|
||||
|
||||
|
||||
def test_models_list_with_base_url(mock_models_list, cli_runner):
|
||||
"""Test the models list command with custom base URL overriding env var"""
|
||||
custom_base_url = "http://custom.server:8000"
|
||||
|
||||
# Run the command with custom base URL
|
||||
result = cli_runner.invoke(cli, ["--base-url", custom_base_url, "models", "list"])
|
||||
|
||||
# Check that the command succeeded
|
||||
assert result.exit_code == 0
|
||||
|
||||
# Verify the client was created with the custom base URL (overriding env var)
|
||||
mock_models_list.assert_called_once_with(
|
||||
base_url=custom_base_url,
|
||||
api_key="sk-test", # Should still use env var for API key
|
||||
)
|
||||
|
||||
|
||||
def test_models_list_with_api_key(mock_models_list, cli_runner):
|
||||
"""Test the models list command with API key overriding env var"""
|
||||
custom_api_key = "custom-test-key"
|
||||
|
||||
# Run the command with custom API key
|
||||
result = cli_runner.invoke(cli, ["--api-key", custom_api_key, "models", "list"])
|
||||
|
||||
# Check that the command succeeded
|
||||
assert result.exit_code == 0
|
||||
|
||||
# Verify the client was created with the custom API key (overriding env var)
|
||||
mock_models_list.assert_called_once_with(
|
||||
base_url="http://localhost:4000", # Should still use env var for base URL
|
||||
api_key=custom_api_key,
|
||||
)
|
||||
|
||||
|
||||
def test_models_list_error_handling(mock_client, cli_runner):
|
||||
"""Test error handling in the models list command"""
|
||||
# Configure mock to raise an exception
|
||||
mock_client.return_value.models.list.side_effect = Exception("API Error")
|
||||
|
||||
# Run the command
|
||||
result = cli_runner.invoke(cli, ["models", "list"])
|
||||
|
||||
# Check that the command failed
|
||||
assert result.exit_code != 0
|
||||
assert "API Error" in str(result.exception)
|
||||
|
||||
# Verify the client was created with env var values
|
||||
mock_client.assert_called_once_with(
|
||||
base_url="http://localhost:4000", api_key="sk-test"
|
||||
)
|
||||
|
||||
|
||||
def test_models_info_json_format(mock_models_info, cli_runner):
    """Test the models info command with JSON output format"""
    # Run the command
    result = cli_runner.invoke(cli, ["models", "info", "--format", "json"])

    # Check that the command succeeded
    assert result.exit_code == 0

    # Parse the output and verify it matches our mock data
    output_data = json.loads(result.output)
    assert output_data == mock_models_info.return_value.models.info.return_value

    # Verify the client was called correctly with env var values
    mock_models_info.assert_called_once_with(
        base_url="http://localhost:4000", api_key="sk-test"
    )
    mock_models_info.return_value.models.info.assert_called_once()


def test_models_info_table_format(mock_models_info, cli_runner):
    """Test the models info command with table output format"""
    # Run the command with the default columns
    result = cli_runner.invoke(cli, ["models", "info"])

    # Check that the command succeeded
    assert result.exit_code == 0

    # Verify the output contains the expected table elements
    assert "Public Model" in result.output
    assert "Upstream Model" in result.output
    assert "Updated At" in result.output
    assert "gpt-4" in result.output
    assert "2025-04-29 21:31" in result.output

    # Verify seconds and microseconds are not shown
    assert "21:31:43" not in result.output
    assert "843000" not in result.output

    # Verify the client was called correctly with env var values
    mock_models_info.assert_called_once_with(
        base_url="http://localhost:4000", api_key="sk-test"
    )
    mock_models_info.return_value.models.info.assert_called_once()


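# The "Updated At" column is rendered at minute granularity, so the test
# checks both that "2025-04-29 21:31" appears and that the seconds (21:31:43)
# and microseconds (843000) of the raw timestamp do not; presumably the format
# helpers exercised further down perform that truncation.

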
def test_models_import_only_models_matching_regex(tmp_path, mock_client, cli_runner):
    """Test the --only-models-matching-regex option for the models import command"""
    # Prepare a YAML file with a mix of models
    yaml_content = {
        "model_list": [
            {
                "model_name": "gpt-4-model",
                "litellm_params": {"model": "gpt-4"},
                "model_info": {"id": "id-1"},
            },
            {
                "model_name": "gpt-3.5-model",
                "litellm_params": {"model": "gpt-3.5-turbo"},
                "model_info": {"id": "id-2"},
            },
            {
                "model_name": "llama2-model",
                "litellm_params": {"model": "llama2"},
                "model_info": {"id": "id-3"},
            },
            {
                "model_name": "other-model",
                "litellm_params": {"model": "other"},
                "model_info": {"id": "id-4"},
            },
        ]
    }
    import yaml as pyyaml

    yaml_file = tmp_path / "models.yaml"
    with open(yaml_file, "w") as f:
        pyyaml.safe_dump(yaml_content, f)

    # Patch client.models.new to track calls
    mock_new = mock_client.return_value.models.new

    # Only match models containing 'gpt' in their litellm_params.model
    result = cli_runner.invoke(
        cli, ["models", "import", str(yaml_file), "--only-models-matching-regex", "gpt"]
    )

    # Should succeed
    assert result.exit_code == 0
    # Only the two gpt models should be imported
    calls = [call.kwargs["model_params"]["model"] for call in mock_new.call_args_list]
    assert set(calls) == {"gpt-4", "gpt-3.5-turbo"}
    # Should not include llama2 or other
    assert "llama2" not in calls
    assert "other" not in calls
    # The output summary should mention the matching provider
    assert "gpt" in result.output


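# A hedged sketch of the filtering the test above relies on (hypothetical
# helper; the real logic lives in the import command). Because re.search
# matches anywhere in the string, the pattern "gpt" keeps both "gpt-4" and
# "gpt-3.5-turbo":
import re


def _filter_models_by_regex_sketch(model_list, pattern):
    """Keep only entries whose litellm_params.model matches the regex."""
    return [
        entry
        for entry in model_list
        if re.search(pattern, entry.get("litellm_params", {}).get("model", ""))
    ]

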
def test_models_import_only_access_groups_matching_regex(
    tmp_path, mock_client, cli_runner
):
    """Test the --only-access-groups-matching-regex option for the models import command"""
    # Prepare a YAML file with a mix of models
    yaml_content = {
        "model_list": [
            {
                "model_name": "gpt-4-model",
                "litellm_params": {"model": "gpt-4"},
                "model_info": {
                    "id": "id-1",
                    "access_groups": ["beta-models", "prod-models"],
                },
            },
            {
                "model_name": "gpt-3.5-model",
                "litellm_params": {"model": "gpt-3.5-turbo"},
                "model_info": {"id": "id-2", "access_groups": ["alpha-models"]},
            },
            {
                "model_name": "llama2-model",
                "litellm_params": {"model": "llama2"},
                "model_info": {"id": "id-3", "access_groups": ["beta-models"]},
            },
            {
                "model_name": "other-model",
                "litellm_params": {"model": "other"},
                "model_info": {"id": "id-4", "access_groups": ["other-group"]},
            },
            {
                "model_name": "no-access-group-model",
                "litellm_params": {"model": "no-access"},
                "model_info": {"id": "id-5"},
            },
        ]
    }
    import yaml as pyyaml

    yaml_file = tmp_path / "models.yaml"
    with open(yaml_file, "w") as f:
        pyyaml.safe_dump(yaml_content, f)

    # Patch client.models.new to track calls
    mock_new = mock_client.return_value.models.new

    # Only match models with access_groups containing 'beta'
    result = cli_runner.invoke(
        cli,
        [
            "models",
            "import",
            str(yaml_file),
            "--only-access-groups-matching-regex",
            "beta",
        ],
    )

    # Should succeed
    assert result.exit_code == 0
    # Only the two models with 'beta-models' in access_groups should be imported
    calls = [call.kwargs["model_params"]["model"] for call in mock_new.call_args_list]
    assert set(calls) == {"gpt-4", "llama2"}
    # Should not include gpt-3.5, other, or no-access
    assert "gpt-3.5-turbo" not in calls
    assert "other" not in calls
    assert "no-access" not in calls
    # The output summary should mention the matching provider
    assert "gpt" in result.output


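# Analogous hedged sketch for the access-group filter (hypothetical helper):
# an entry is kept when *any* of its model_info.access_groups matches, so
# models without access_groups (like "no-access" above) are always dropped.
def _filter_models_by_access_group_sketch(model_list, pattern):
    return [
        entry
        for entry in model_list
        if any(
            re.search(pattern, group)
            for group in entry.get("model_info", {}).get("access_groups", [])
        )
    ]

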
@pytest.mark.parametrize(
    "input_str,expected",
    [
        (None, ""),
        ("", ""),
        ("2024-05-01T12:34:56Z", "2024-05-01 12:34"),
        ("2024-05-01T12:34:56+00:00", "2024-05-01 12:34"),
        ("2024-05-01T12:34:56.123456+00:00", "2024-05-01 12:34"),
        ("2024-05-01T12:34:56.123456Z", "2024-05-01 12:34"),
        ("2024-05-01T12:34:56-04:00", "2024-05-01 12:34"),
        ("2024-05-01", "2024-05-01 00:00"),
        ("not-a-date", "not-a-date"),
    ],
)
def test_format_iso_datetime_str(input_str, expected):
    assert format_iso_datetime_str(input_str) == expected


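# A minimal sketch consistent with the table above (hypothetical
# implementation, not the helper shipped with the CLI package):
from datetime import datetime


def _format_iso_datetime_str_sketch(value):
    if not value:  # covers both None and ""
        return ""
    try:
        # fromisoformat() rejects a trailing "Z" on older Pythons, so
        # normalize it to an explicit UTC offset first
        parsed = datetime.fromisoformat(value.replace("Z", "+00:00"))
    except ValueError:
        return value  # non-ISO strings pass through unchanged
    return parsed.strftime("%Y-%m-%d %H:%M")

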
@pytest.mark.parametrize(
    "input_val,expected",
    [
        (None, ""),
        (1699848889, "2023-11-13 04:14"),
        (1699848889.0, "2023-11-13 04:14"),
        ("not-a-timestamp", "not-a-timestamp"),
        ([1, 2, 3], "[1, 2, 3]"),
    ],
)
def test_format_timestamp(input_val, expected, force_utc_tz):
    actual = format_timestamp(input_val)
    if actual != expected:
        print(f"input: {input_val}, expected: {expected}, actual: {actual}")
    assert actual == expected


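# Hedged sketch matching the cases above; the force_utc_tz fixture (defined in
# the test configuration) pins the test to UTC, so epoch seconds render as UTC
# wall time. Hypothetical implementation:
from datetime import datetime, timezone


def _format_timestamp_sketch(value):
    if value is None:
        return ""
    if isinstance(value, (int, float)):
        return datetime.fromtimestamp(value, tz=timezone.utc).strftime(
            "%Y-%m-%d %H:%M"
        )
    return str(value)  # e.g. [1, 2, 3] -> "[1, 2, 3]"

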
@pytest.mark.parametrize(
    "input_val,expected",
    [
        (None, ""),
        (0, "$0.0000"),
        (0.0, "$0.0000"),
        (0.00001, "$0.0100"),
        (0.00002, "$0.0200"),
        (1, "$1000.0000"),
        (1.5, "$1500.0000"),
        ("0.00001", "$0.0100"),
        ("1.5", "$1500.0000"),
        ("not-a-number", "not-a-number"),
        (1e-10, "$0.0000"),
    ],
)
def test_format_cost_per_1k_tokens(input_val, expected):
    actual = format_cost_per_1k_tokens(input_val)
    if actual != expected:
        print(f"input: {input_val}, expected: {expected}, actual: {actual}")
    assert actual == expected
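

# Hedged sketch matching the table: the helper scales a per-token cost to a
# per-1k-token dollar figure with four decimal places, and passes non-numeric
# input through unchanged (hypothetical implementation):
def _format_cost_per_1k_tokens_sketch(value):
    if value is None:
        return ""
    try:
        cost = float(value)
    except (TypeError, ValueError):
        return value  # e.g. "not-a-number"
    return f"${cost * 1000:.4f}"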
@@ -0,0 +1,96 @@
import os
import sys
from unittest.mock import patch

import pytest
from click.testing import CliRunner

sys.path.insert(
    0, os.path.abspath("../../..")
)  # Adds the parent directory to the system path


from litellm.proxy.client.cli import cli


@pytest.fixture
def cli_runner():
    return CliRunner()


@pytest.fixture(autouse=True)
def mock_env():
    with patch.dict(
        "os.environ",
        {
            "LITELLM_PROXY_URL": "http://localhost:4000",
            "LITELLM_PROXY_API_KEY": "sk-test",
        },
    ):
        yield


@pytest.fixture
def mock_users_client():
    with patch(
        "litellm.proxy.client.cli.commands.users.UsersManagementClient"
    ) as MockClient:
        yield MockClient


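# Because the class itself is patched, mock_users_client.return_value is the
# UsersManagementClient instance the command code constructs, so the tests
# below stub and assert against methods like
# mock_users_client.return_value.list_users.

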
def test_users_list(cli_runner, mock_users_client):
    mock_users_client.return_value.list_users.return_value = [
        {
            "user_id": "u1",
            "user_email": "a@b.com",
            "user_role": "internal_user",
            "teams": ["t1"],
        },
        {
            "user_id": "u2",
            "user_email": "b@b.com",
            "user_role": "proxy_admin",
            "teams": ["t2", "t3"],
        },
    ]
    result = cli_runner.invoke(cli, ["users", "list"])
    assert result.exit_code == 0
    assert "u1" in result.output
    assert "a@b.com" in result.output
    assert "proxy_admin" in result.output
    assert "t3" in result.output
    mock_users_client.return_value.list_users.assert_called_once()


def test_users_get(cli_runner, mock_users_client):
    mock_users_client.return_value.get_user.return_value = {
        "user_id": "u1",
        "user_email": "a@b.com",
    }
    result = cli_runner.invoke(cli, ["users", "get", "--id", "u1"])
    assert result.exit_code == 0
    assert '"user_id": "u1"' in result.output
    assert '"user_email": "a@b.com"' in result.output
    mock_users_client.return_value.get_user.assert_called_once_with(user_id="u1")


def test_users_create(cli_runner, mock_users_client):
    mock_users_client.return_value.create_user.return_value = {
        "user_id": "u1",
        "user_email": "a@b.com",
    }
    result = cli_runner.invoke(
        cli, ["users", "create", "--email", "a@b.com", "--role", "internal_user"]
    )
    assert result.exit_code == 0
    assert '"user_id": "u1"' in result.output
    assert '"user_email": "a@b.com"' in result.output
    mock_users_client.return_value.create_user.assert_called_once()


def test_users_delete(cli_runner, mock_users_client):
    mock_users_client.return_value.delete_user.return_value = {"deleted": 1}
    result = cli_runner.invoke(cli, ["users", "delete", "u1", "u2"])
    assert result.exit_code == 0
    assert '"deleted": 1' in result.output
    mock_users_client.return_value.delete_user.assert_called_once_with(["u1", "u2"])