Added LiteLLM to the stack
This commit is contained in:
@@ -0,0 +1,205 @@
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
)  # Adds the directory two levels above the current working directory to the import path
|
||||
import litellm
|
||||
from litellm import completion
|
||||
from litellm.llms.custom_httpx.http_handler import HTTPHandler
|
||||
from unittest.mock import patch, Mock
|
||||
import pytest
|
||||
from typing import Optional
|
||||
|
||||
|
||||
@pytest.fixture
def watsonx_chat_completion_call():
    """Provide a helper that runs a watsonx chat completion over mocked HTTP.

    The fixture returns ``_call``. Tests invoke it and then inspect the
    returned mocks to check the outgoing request.
    """

    def _call(
        model="watsonx/my-test-model",
        messages=None,
        api_key="test_api_key",
        space_id: Optional[str] = None,
        headers=None,
        client=None,
        patch_token_call=True,
    ):
        """Call ``completion`` with ``client.post`` patched and return the mocks.

        Args:
            model: Model string passed to ``completion``.
            messages: Chat messages; defaults to a single user greeting.
            api_key: API key forwarded to ``completion``.
            space_id: Optional space id forwarded to ``completion``.
            headers: Extra request headers; ``None`` is sent as ``{}``.
            client: Client whose ``post`` is patched. A new ``HTTPHandler``
                is created when omitted.
            patch_token_call: When true, ``litellm.module_level_client.post``
                is also patched to return a canned access-token response.

        Returns:
            A ``(mock_post, mock_token_post)`` tuple. ``mock_token_post`` is
            ``None`` when ``patch_token_call`` is false.
        """
        # Local import keeps this fixture self-contained.
        from contextlib import nullcontext

        if messages is None:
            messages = [{"role": "user", "content": "Hello, how are you?"}]
        if client is None:
            client = HTTPHandler()

        if patch_token_call:
            token_response = Mock()
            token_response.json.return_value = {
                "access_token": "mock_access_token",
                "expires_in": 3600,
            }
            token_response.raise_for_status = Mock()  # No-op to simulate no exception
            token_patch = patch.object(
                litellm.module_level_client, "post", return_value=token_response
            )
        else:
            # nullcontext() binds None, which is exactly the "no token mock"
            # value callers receive in this mode.
            token_patch = nullcontext()

        with patch.object(client, "post") as mock_post, token_patch as mock_token_post:
            try:
                completion(
                    model=model,
                    messages=messages,
                    api_key=api_key,
                    headers=headers or {},
                    client=client,
                    space_id=space_id,
                )
            except Exception as e:
                # Deliberately best-effort: callers only inspect the recorded
                # request, so any failure raised by the call is just printed.
                print(e)

        return mock_post, mock_token_post

    return _call
|
||||
|
||||
|
||||
def test_watsonx_deployment_model_id_not_in_payload(
    monkeypatch, watsonx_chat_completion_call
):
    """Chat requests for deployment models must omit ``model_id`` and ``project_id``.

    A key that is present but set to ``None`` is accepted as equivalent to
    the key being absent.
    """
    for env_name, env_value in (
        ("WATSONX_PROJECT_ID", "test-project-id"),
        ("WATSONX_API_BASE", "https://test-api.watsonx.ai"),
    ):
        monkeypatch.setenv(env_name, env_value)

    request_mock, _ = watsonx_chat_completion_call(
        model="watsonx/deployment/test-deployment-id",
        messages=[{"role": "user", "content": "Test message"}],
    )

    # Exactly one request should have gone through the patched client.
    assert request_mock.call_count == 1
    payload = json.loads(request_mock.call_args.kwargs["data"])

    # Neither identifier may carry a value for deployment models.
    for forbidden_key in ("model_id", "project_id"):
        assert forbidden_key not in payload or payload[forbidden_key] is None
|
||||
|
||||
|
||||
def test_watsonx_regular_model_includes_model_id(
    monkeypatch, watsonx_chat_completion_call
):
    """Chat requests for regular models must carry ``model_id`` and ``project_id``."""
    for env_name, env_value in (
        ("WATSONX_PROJECT_ID", "test-project-id"),
        ("WATSONX_API_BASE", "https://test-api.watsonx.ai"),
    ):
        monkeypatch.setenv(env_name, env_value)

    request_mock, _ = watsonx_chat_completion_call(
        model="watsonx/regular-model",
        messages=[{"role": "user", "content": "Test message"}],
    )

    # Exactly one request should have gone through the patched client.
    assert request_mock.call_count == 1
    payload = json.loads(request_mock.call_args.kwargs["data"])

    # The model id is expected without its provider prefix.
    assert "model_id" in payload
    assert payload["model_id"] == "regular-model"
    # Regular models are also expected to send the project id.
    assert "project_id" in payload
|
||||
|
||||
|
||||
@pytest.fixture
def watsonx_completion_call():
    """Provide a helper that runs a watsonx text completion over mocked HTTP.

    The fixture returns ``_call``. Tests invoke it and then inspect the
    returned mocks to check the outgoing request.
    """

    def _call(
        model="watsonx_text/my-test-model",
        prompt="Hello, how are you?",
        api_key="test_api_key",
        space_id: Optional[str] = None,
        headers=None,
        client=None,
        patch_token_call=True,
    ):
        """Call ``litellm.text_completion`` with ``client.post`` patched.

        Args:
            model: Model string passed to ``litellm.text_completion``.
            prompt: Prompt text to send.
            api_key: API key forwarded to ``litellm.text_completion``.
            space_id: Optional space id forwarded to the call.
            headers: Extra request headers; ``None`` is sent as ``{}``.
            client: Client whose ``post`` is patched. A new ``HTTPHandler``
                is created when omitted.
            patch_token_call: When true, ``litellm.module_level_client.post``
                is also patched to return a canned access-token response.

        Returns:
            A ``(mock_post, mock_token_post)`` tuple. ``mock_token_post`` is
            ``None`` when ``patch_token_call`` is false.
        """
        # Local import keeps this fixture self-contained.
        from contextlib import nullcontext

        if client is None:
            client = HTTPHandler()

        if patch_token_call:
            token_response = Mock()
            token_response.json.return_value = {
                "access_token": "mock_access_token",
                "expires_in": 3600,
            }
            token_response.raise_for_status = Mock()  # No-op to simulate no exception
            token_patch = patch.object(
                litellm.module_level_client, "post", return_value=token_response
            )
        else:
            # nullcontext() binds None, which is exactly the "no token mock"
            # value callers receive in this mode.
            token_patch = nullcontext()

        with patch.object(client, "post") as mock_post, token_patch as mock_token_post:
            try:
                litellm.text_completion(
                    model=model,
                    prompt=prompt,
                    api_key=api_key,
                    headers=headers or {},
                    client=client,
                    space_id=space_id,
                )
            except Exception as e:
                # Deliberately best-effort: callers only inspect the recorded
                # request, so any failure raised by the call is just printed.
                print(e)

        return mock_post, mock_token_post

    return _call
|
||||
|
||||
|
||||
def test_watsonx_completion_deployment_model_id_not_in_payload(
    monkeypatch, watsonx_completion_call
):
    """Text-completion requests for deployment models must omit ``model_id`` and ``project_id``."""
    for env_name, env_value in (
        ("WATSONX_PROJECT_ID", "test-project-id"),
        ("WATSONX_API_BASE", "https://test-api.watsonx.ai"),
    ):
        monkeypatch.setenv(env_name, env_value)

    request_mock, _ = watsonx_completion_call(
        model="watsonx_text/deployment/test-deployment-id",
        prompt="Test prompt",
    )

    # Exactly one request should have gone through the patched client.
    assert request_mock.call_count == 1
    payload = json.loads(request_mock.call_args.kwargs["data"])

    # Both identifiers must be entirely absent for deployment models.
    for forbidden_key in ("model_id", "project_id"):
        assert forbidden_key not in payload
|
||||
|
||||
|
||||
def test_watsonx_completion_regular_model_includes_model_id(
    monkeypatch, watsonx_completion_call
):
    """Text-completion requests for regular models must carry ``model_id`` and ``project_id``."""
    for env_name, env_value in (
        ("WATSONX_PROJECT_ID", "test-project-id"),
        ("WATSONX_API_BASE", "https://test-api.watsonx.ai"),
    ):
        monkeypatch.setenv(env_name, env_value)

    request_mock, _ = watsonx_completion_call(
        model="watsonx_text/regular-model",
        prompt="Test prompt",
    )

    # Exactly one request should have gone through the patched client.
    assert request_mock.call_count == 1
    payload = json.loads(request_mock.call_args.kwargs["data"])

    # The model id is expected without its provider prefix.
    assert "model_id" in payload
    assert payload["model_id"] == "regular-model"
    # Regular models are also expected to send the project id.
    assert "project_id" in payload
|
Reference in New Issue
Block a user