Added LiteLLM to the stack
@@ -0,0 +1,103 @@
import os
import sys
from unittest.mock import MagicMock, patch

import pytest

sys.path.insert(
    0, os.path.abspath("../../../../..")
)  # Adds the parent directory to the system path

from litellm.llms.hosted_vllm.chat.transformation import HostedVLLMChatConfig


def test_hosted_vllm_chat_transformation_file_url():
    config = HostedVLLMChatConfig()
    video_url = "https://example.com/video.mp4"
    # Placeholder payload: a real data URI would carry base64-encoded bytes
    # after the "base64," prefix; the transformation passes the string through as-is.
    video_data = f"data:video/mp4;base64,{video_url}"
    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "file",
                    "file": {
                        "file_data": video_data,
                    },
                }
            ],
        }
    ]
    transformed_response = config.transform_request(
        model="hosted_vllm/llama-3.1-70b-instruct",
        messages=messages,
        optional_params={},
        litellm_params={},
        headers={},
    )
    # OpenAI-style "file" content parts are rewritten into vLLM "video_url" parts.
    assert transformed_response["messages"] == [
        {
            "role": "user",
            "content": [{"type": "video_url", "video_url": {"url": video_data}}],
        }
    ]


def test_hosted_vllm_chat_transformation_with_audio_url():
    from litellm import completion

    client = MagicMock()

    with patch.object(
        client.chat.completions.with_raw_response, "create", return_value=MagicMock()
    ) as mock_post:
        try:
            response = completion(
                model="hosted_vllm/llama-3.1-70b-instruct",
                messages=[
                    {
                        "role": "user",
                        "content": [
                            {
                                "type": "audio_url",
                                "audio_url": {"url": "https://example.com/audio.mp3"},
                            },
                        ],
                    },
                ],
                client=client,
            )
        except Exception as e:
            # The mocked client returns a bare MagicMock, so response parsing may
            # fail; only the outbound request payload matters for this test.
            print(f"Error: {e}")

        mock_post.assert_called_once()
        print(f"mock_post.call_args.kwargs: {mock_post.call_args.kwargs}")
        # "audio_url" content parts are forwarded to the provider unchanged.
        assert mock_post.call_args.kwargs["messages"] == [
            {
                "role": "user",
                "content": [
                    {
                        "type": "audio_url",
                        "audio_url": {"url": "https://example.com/audio.mp3"},
                    }
                ],
            }
        ]


def test_hosted_vllm_supports_reasoning_effort():
    config = HostedVLLMChatConfig()
    supported_params = config.get_supported_openai_params(
        model="hosted_vllm/gpt-oss-120b"
    )
    assert "reasoning_effort" in supported_params
    optional_params = config.map_openai_params(
        non_default_params={"reasoning_effort": "high"},
        optional_params={},
        model="hosted_vllm/gpt-oss-120b",
        drop_params=False,
    )
    assert optional_params["reasoning_effort"] == "high"
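For reference, a minimal sketch of how a caller might exercise the file-to-video_url rewrite covered above against a live server. Illustrative only: the api_base, api_key, file path, and server setup below are placeholder assumptions, not part of this commit.

# Sketch: send a video file to a vLLM OpenAI-compatible server via litellm.
# Assumes a server at http://localhost:8000/v1 (placeholder, as is the file path).
import base64

import litellm

with open("clip.mp4", "rb") as f:  # hypothetical local file
    encoded = base64.b64encode(f.read()).decode("utf-8")

response = litellm.completion(
    model="hosted_vllm/llama-3.1-70b-instruct",
    api_base="http://localhost:8000/v1",  # placeholder endpoint
    api_key="fake-key",  # vLLM typically ignores the key unless configured
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Describe this video."},
                # HostedVLLMChatConfig rewrites this "file" part into
                # {"type": "video_url", "video_url": {"url": ...}} before sending.
                {
                    "type": "file",
                    "file": {"file_data": f"data:video/mp4;base64,{encoded}"},
                },
            ],
        }
    ],
)
print(response.choices[0].message.content)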
@@ -0,0 +1,81 @@
import os
import sys

import pytest

from litellm.llms.hosted_vllm.rerank.transformation import HostedVLLMRerankConfig


class TestHostedVLLMRerankTransform:
    def setup_method(self):
        self.config = HostedVLLMRerankConfig()
        self.model = "hosted-vllm-model"

    def test_map_cohere_rerank_params_basic(self):
        params = self.config.map_cohere_rerank_params(
            non_default_params=None,
            model=self.model,
            drop_params=False,
            query="test query",
            documents=["doc1", "doc2"],
            top_n=2,
            rank_fields=["field1"],
            return_documents=True,
        )
        assert params["query"] == "test query"
        assert params["documents"] == ["doc1", "doc2"]
        assert params["top_n"] == 2
        assert params["rank_fields"] == ["field1"]
        assert params["return_documents"] is True

    def test_map_cohere_rerank_params_raises_on_max_chunks_per_doc(self):
        with pytest.raises(
            ValueError, match="Hosted VLLM does not support max_chunks_per_doc"
        ):
            self.config.map_cohere_rerank_params(
                non_default_params=None,
                model=self.model,
                drop_params=False,
                query="test query",
                documents=["doc1"],
                max_chunks_per_doc=5,
            )

    def test_get_complete_url(self):
        base = "https://api.example.com"
        url = self.config.get_complete_url(base, self.model)
        assert url == "https://api.example.com/v1/rerank"
        # Already ends with /v1/rerank
        url2 = self.config.get_complete_url(
            "https://api.example.com/v1/rerank", self.model
        )
        assert url2 == "https://api.example.com/v1/rerank"
        # Raises if api_base is None
        with pytest.raises(ValueError):
            self.config.get_complete_url(None, self.model)

    def test_transform_response(self):
        response_dict = {
            "id": "abc123",
            "results": [
                {"index": 0, "relevance_score": 0.9, "document": {"text": "doc1 text"}},
                {"index": 1, "relevance_score": 0.7, "document": {"text": "doc2 text"}},
            ],
            "usage": {"total_tokens": 42},
        }
        result = self.config._transform_response(response_dict)
        assert result.id == "abc123"
        assert len(result.results) == 2
        assert result.results[0]["index"] == 0
        assert result.results[0]["relevance_score"] == 0.9
        assert result.results[0]["document"]["text"] == "doc1 text"
        # Token usage is surfaced in both the billed_units and tokens metadata.
        assert result.meta["billed_units"]["total_tokens"] == 42
        assert result.meta["tokens"]["input_tokens"] == 42

    def test_transform_response_missing_results(self):
        response_dict = {"id": "abc123", "usage": {"total_tokens": 10}}
        with pytest.raises(ValueError, match="No results found in the response="):
            self.config._transform_response(response_dict)

    def test_transform_response_missing_required_fields(self):
        response_dict = {
            "id": "abc123",
            "results": [{"relevance_score": 0.5}],
            "usage": {"total_tokens": 10},
        }
        with pytest.raises(ValueError, match="Missing required fields in the result="):
            self.config._transform_response(response_dict)
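For reference, a rough sketch of driving this rerank path through litellm's top-level rerank API. Illustrative only: the endpoint URL, api_key, and model name are placeholder assumptions; get_complete_url, as tested above, is what appends /v1/rerank to the base URL.

# Sketch: rerank documents against a vLLM server exposing /v1/rerank.
# Assumes a server at http://localhost:8000 (placeholder, as is the model name).
import litellm

response = litellm.rerank(
    model="hosted_vllm/BAAI/bge-reranker-base",  # placeholder reranker model
    api_base="http://localhost:8000",  # /v1/rerank is appended automatically
    api_key="fake-key",
    query="What is the capital of France?",
    documents=["Paris is the capital of France.", "Berlin is in Germany."],
    top_n=1,
    return_documents=True,
)
for r in response.results:
    print(r["index"], r["relevance_score"])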