Added LiteLLM to the stack
This commit is contained in:
@@ -0,0 +1,69 @@
|
||||
import os
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from litellm.llms.datarobot.chat.transformation import DataRobotConfig
|
||||
|
||||
|
||||
@patch.dict(os.environ, {}, clear=True)
|
||||
class TestDataRobotConfig:
|
||||
@pytest.fixture
|
||||
def handler(self):
|
||||
return DataRobotConfig()
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"api_base, expected_url",
|
||||
[
|
||||
(None, "https://app.datarobot.com/api/v2/genai/llmgw/chat/completions/"),
|
||||
("http://localhost:5001", "http://localhost:5001/api/v2/genai/llmgw/chat/completions/"),
|
||||
("https://app.datarobot.com", "https://app.datarobot.com/api/v2/genai/llmgw/chat/completions/"),
|
||||
("https://app.datarobot.com/api/v2/genai/llmgw/chat/completions", "https://app.datarobot.com/api/v2/genai/llmgw/chat/completions/"),
|
||||
("https://app.datarobot.com/api/v2/genai/llmgw/chat/completions/", "https://app.datarobot.com/api/v2/genai/llmgw/chat/completions/"),
|
||||
("https://staging.datarobot.com", "https://staging.datarobot.com/api/v2/genai/llmgw/chat/completions/"),
|
||||
("https://app.datarobot.com/api/v2/deployments/deployment_id", "https://app.datarobot.com/api/v2/deployments/deployment_id/"),
|
||||
("https://app.datarobot.com/api/v2/deployments/deployment_id/", "https://app.datarobot.com/api/v2/deployments/deployment_id/"),
|
||||
]
|
||||
)
|
||||
def test_resolve_api_base(self, api_base, expected_url, handler):
|
||||
"""Test that URLs properly resolve to the expected format."""
|
||||
assert handler._resolve_api_base(api_base) == expected_url
|
||||
|
||||
# Check that the complete url with the resolution is expected
|
||||
assert handler.get_complete_url(
|
||||
api_base=handler._resolve_api_base(api_base),
|
||||
api_key="PASSTHROUGH_KEY",
|
||||
model="datarobot/vertex_ai/gemini-1.5-flash-002",
|
||||
optional_params={},
|
||||
litellm_params={},
|
||||
) == expected_url
|
||||
|
||||
# Check that the complete url with the original api_base does not change the url
|
||||
if api_base is not None:
|
||||
assert handler.get_complete_url(
|
||||
api_base=api_base,
|
||||
api_key="PASSTHROUGH_KEY",
|
||||
model="datarobot/vertex_ai/gemini-1.5-flash-002",
|
||||
optional_params={},
|
||||
litellm_params={},
|
||||
) == api_base
|
||||
|
||||
def test_resolve_api_base_with_environment_variable(self, handler):
|
||||
os.environ["DATAROBOT_ENDPOINT"] = "https://env.datarobot.com"
|
||||
assert handler._resolve_api_base(None) == "https://env.datarobot.com/api/v2/genai/llmgw/chat/completions/"
|
||||
del os.environ["DATAROBOT_ENDPOINT"]
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"api_key, expected_api_key",
|
||||
[
|
||||
(None, "fake-api-key"),
|
||||
("PASSTHROUGH_KEY", "PASSTHROUGH_KEY"),
|
||||
]
|
||||
)
|
||||
def test_resolve_api_key(self, api_key, expected_api_key, handler):
|
||||
assert handler._resolve_api_key(api_key) == expected_api_key
|
||||
|
||||
def test_resolve_api_key_with_environment_variable(self, handler):
|
||||
os.environ["DATAROBOT_API_TOKEN"] = "env_key"
|
||||
assert handler._resolve_api_key(None) == "env_key"
|
||||
del os.environ["DATAROBOT_API_TOKEN"]
|
@@ -0,0 +1,85 @@
|
||||
import json
|
||||
import os
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from litellm import completion
|
||||
from litellm.llms.custom_httpx.http_handler import HTTPHandler
|
||||
|
||||
|
||||
@patch.dict(os.environ, {}, clear=True)
|
||||
def test_completion_datarobot():
|
||||
"""Ensure that the completion function works with DataRobot API."""
|
||||
messages = [{"role": "user", "content": "What's the weather like in San Francisco?"}]
|
||||
try:
|
||||
client = HTTPHandler()
|
||||
with patch.object(client, "post") as mock_post:
|
||||
response = completion(
|
||||
model="datarobot/vertex_ai/gemini-1.5-flash-002",
|
||||
messages=messages,
|
||||
client=client,
|
||||
max_tokens=5,
|
||||
clientId="custom-model",
|
||||
)
|
||||
print(response)
|
||||
|
||||
# Add any assertions here to check the response
|
||||
mock_post.assert_called_once()
|
||||
mocks_kwargs = mock_post.call_args.kwargs
|
||||
assert mocks_kwargs["url"] == "https://app.datarobot.com/api/v2/genai/llmgw/chat/completions/"
|
||||
assert mocks_kwargs["headers"]["Authorization"] == "Bearer fake-api-key"
|
||||
json_data = json.loads(mock_post.call_args.kwargs["data"])
|
||||
assert json_data["clientId"] == "custom-model"
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
|
||||
@patch.dict(
|
||||
os.environ, {"DATAROBOT_ENDPOINT": "https://app.datarobot.com/api/v2/deployments/deployment_id/"}, clear=True
|
||||
)
|
||||
def test_completion_datarobot_with_deployment():
|
||||
"""Ensure that deployment URL is used correctly."""
|
||||
messages = [{"role": "user", "content": "What's the weather like in San Francisco?"}]
|
||||
try:
|
||||
client = HTTPHandler()
|
||||
with patch.object(client, "post") as mock_post:
|
||||
response = completion(
|
||||
model="datarobot/vertex_ai/gemini-1.5-flash-002",
|
||||
messages=messages,
|
||||
client=client,
|
||||
max_tokens=5,
|
||||
clientId="custom-model",
|
||||
)
|
||||
print(response)
|
||||
|
||||
# Add any assertions here to check the response
|
||||
mock_post.assert_called_once()
|
||||
mocks_kwargs = mock_post.call_args.kwargs
|
||||
assert mocks_kwargs["url"] == "https://app.datarobot.com/api/v2/deployments/deployment_id/"
|
||||
assert mocks_kwargs["headers"]["Authorization"] == "Bearer fake-api-key"
|
||||
json_data = json.loads(mock_post.call_args.kwargs["data"])
|
||||
assert json_data["clientId"] == "custom-model"
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
|
||||
def test_completion_datarobot_with_environment_variables():
|
||||
"""Allow the test to run with environment variables if they are set for integrations."""
|
||||
# If keys are not set, the test will be skipped
|
||||
if os.environ.get("DATAROBOT_API_TOKEN") is None:
|
||||
return
|
||||
|
||||
messages = [{"role": "user", "content": "What's the weather like in San Francisco?"}]
|
||||
try:
|
||||
response = completion(
|
||||
model="datarobot/vertex_ai/gemini-1.5-flash-002", messages=messages, max_tokens=5, clientId="custom-model"
|
||||
)
|
||||
print(response)
|
||||
assert response["object"] == "chat.completion"
|
||||
assert response["model"] == "gemini-1.5-flash-002"
|
||||
assert len(response["choices"]) == 1
|
||||
assert len(response["choices"][0]["message"]["content"]) > 0
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
Reference in New Issue
Block a user