"""
|
|
Test HuggingFace LLM
|
|
"""
|
|
|
|
import json
|
|
import os
|
|
import sys
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
from base_llm_unit_tests import BaseLLMChatTest
|
|
|
|
sys.path.insert(
|
|
0, os.path.abspath("../..")
|
|
) # Adds the parent directory to the system path
|
|
|
|
import pytest
|
|
|
|
import litellm
|
|
from litellm.types.utils import ModelResponse, ModelResponseStream
|
|
|
|
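# Canned OpenAI-compatible chat completion payload that the mocked HTTP
# handlers below return for non-streaming requests.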
MOCK_COMPLETION_RESPONSE = {
    "id": "9115d3daeab10608",
    "object": "chat.completion",
    "created": 11111,
    "model": "meta-llama/Meta-Llama-3-8B-Instruct",
    "prompt": [],
    "choices": [
        {
            "finish_reason": "stop",
            "seed": 3629048360264764400,
            "logprobs": None,
            "index": 0,
            "message": {
                "role": "assistant",
                "content": "This is a test response from the mocked HuggingFace API.",
                "tool_calls": [],
            },
        }
    ],
    "usage": {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30},
}


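# Streaming chunks that the mocked handlers serialize as server-sent-event
# lines ("data: {...}") followed by a terminating "data: [DONE]" sentinel.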
MOCK_STREAMING_CHUNKS = [
    {
        "id": "id1",
        "object": "chat.completion.chunk",
        "created": 1111,
        "choices": [
            {
                "index": 0,
                "text": "Deep",
                "logprobs": None,
                "finish_reason": None,
                "seed": None,
                "delta": {
                    "token_id": 34564,
                    "role": "assistant",
                    "content": "Deep",
                    "tool_calls": None,
                },
            }
        ],
        "model": "meta-llama/Meta-Llama-3-8B-Instruct-Turbo",
        "usage": None,
    },
    {
        "id": "id2",
        "object": "chat.completion.chunk",
        "created": 1111,
        "choices": [
            {
                "index": 0,
                "text": " learning",
                "logprobs": None,
                "finish_reason": None,
                "seed": None,
                "delta": {
                    "token_id": 6975,
                    "role": "assistant",
                    "content": " learning",
                    "tool_calls": None,
                },
            }
        ],
        "model": "meta-llama/Meta-Llama-3-8B-Instruct-Turbo",
        "usage": None,
    },
    {
        "id": "id3",
        "object": "chat.completion.chunk",
        "created": 1111,
        "choices": [
            {
                "index": 0,
                "text": " is",
                "logprobs": None,
                "finish_reason": None,
                "seed": None,
                "delta": {
                    "token_id": 374,
                    "role": "assistant",
                    "content": " is",
                    "tool_calls": None,
                },
            }
        ],
        "model": "meta-llama/Meta-Llama-3-8B-Instruct-Turbo",
        "usage": None,
    },
    {
        "id": "sid4",
        "object": "chat.completion.chunk",
        "created": 1111,
        "choices": [
            {
                "index": 0,
                "text": " response",
                "logprobs": None,
                "finish_reason": "length",
                "seed": 2853637492034609700,
                "delta": {
                    "token_id": 323,
                    "role": "assistant",
                    "content": " response",
                    "tool_calls": None,
                },
            }
        ],
        "model": "meta-llama/Meta-Llama-3-8B-Instruct-Turbo",
        "usage": {"prompt_tokens": 26, "completion_tokens": 20, "total_tokens": 46},
    },
]


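# Static inference-provider mapping used in place of the real lookup; the
# mock_provider_mapping fixture patches _fetch_inference_provider_mapping
# to return it.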
PROVIDER_MAPPING_RESPONSE = {
    "fireworks-ai": {
        "status": "live",
        "providerId": "accounts/fireworks/models/llama-v3-8b-instruct",
        "task": "conversational",
    },
    "together": {
        "status": "live",
        "providerId": "meta-llama/Meta-Llama-3-8B-Instruct-Turbo",
        "task": "conversational",
    },
    "hf-inference": {
        "status": "live",
        "providerId": "meta-llama/Meta-Llama-3-8B-Instruct",
        "task": "conversational",
    },
}


@pytest.fixture
def mock_provider_mapping():
    """Patch the provider-mapping lookup to return the static PROVIDER_MAPPING_RESPONSE."""
    with patch(
        "litellm.llms.huggingface.chat.transformation._fetch_inference_provider_mapping"
    ) as mock:
        mock.return_value = PROVIDER_MAPPING_RESPONSE
        yield mock


@pytest.fixture(autouse=True)
def clear_lru_cache():
    """Clear the cached provider mapping before and after each test."""
    from litellm.llms.huggingface.common_utils import _fetch_inference_provider_mapping

    _fetch_inference_provider_mapping.cache_clear()
    yield
    _fetch_inference_provider_mapping.cache_clear()


@pytest.fixture
def mock_http_handler():
    """Fixture to mock the HTTP handler"""
    with patch("litellm.llms.custom_httpx.http_handler.HTTPHandler.post") as mock:
        print(f"Creating mock HTTP handler: {mock}")  # noqa: T201

        mock_response = MagicMock()
        mock_response.raise_for_status.return_value = None
        mock_response.status_code = 200

        def mock_side_effect(*args, **kwargs):
            if kwargs.get("stream", True):
                mock_response.iter_lines.return_value = iter(
                    [
                        f"data: {json.dumps(chunk)}".encode("utf-8")
                        for chunk in MOCK_STREAMING_CHUNKS
                    ]
                    + [b"data: [DONE]"]
                )
            else:
                mock_response.json.return_value = MOCK_COMPLETION_RESPONSE
            return mock_response

        mock.side_effect = mock_side_effect
        yield mock


@pytest.fixture
def mock_http_async_handler():
    """Fixture to mock the async HTTP handler"""
    with patch(
        "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post",
        new_callable=AsyncMock,
    ) as mock:
        print(f"Creating mock async HTTP handler: {mock}")  # noqa: T201

        mock_response = MagicMock()
        mock_response.raise_for_status.return_value = None
        mock_response.status_code = 200
        mock_response.headers = {"content-type": "application/json"}

        mock_response.json.return_value = MOCK_COMPLETION_RESPONSE
        mock_response.text = json.dumps(MOCK_COMPLETION_RESPONSE)

        async def mock_side_effect(*args, **kwargs):
            if kwargs.get("stream", True):

                async def mock_aiter():
                    for chunk in MOCK_STREAMING_CHUNKS:
                        yield f"data: {json.dumps(chunk)}".encode("utf-8")
                    yield b"data: [DONE]"

                mock_response.aiter_lines = mock_aiter
            return mock_response

        mock.side_effect = mock_side_effect
        yield mock


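# The suite below reuses the shared BaseLLMChatTest checks and adds
# HuggingFace-specific ones; the autouse setup fixture wires in the mocked
# HTTP handlers so no real requests are sent.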
class TestHuggingFace(BaseLLMChatTest):
    @pytest.fixture(autouse=True)
    def setup(self, mock_provider_mapping, mock_http_handler, mock_http_async_handler):
        self.mock_provider_mapping = mock_provider_mapping
        self.mock_http = mock_http_handler
        self.mock_http_async = mock_http_async_handler
        self.model = "huggingface/together/meta-llama/Meta-Llama-3-8B-Instruct"
        litellm.set_verbose = False

    def get_base_completion_call_args(self) -> dict:
        """Implementation of abstract method from BaseLLMChatTest"""
        return {"model": self.model}

    def test_completion_non_streaming(self):
        messages = [{"role": "user", "content": "This is a dummy message"}]

        response = litellm.completion(model=self.model, messages=messages, stream=False)
        assert isinstance(response, ModelResponse)
        assert (
            response.choices[0].message.content
            == "This is a test response from the mocked HuggingFace API."
        )
        assert response.usage is not None
        assert response.model == self.model.split("/", 2)[2]

    def test_completion_streaming(self):
        messages = [{"role": "user", "content": "This is a dummy message"}]

        response = litellm.completion(model=self.model, messages=messages, stream=True)

        chunks = list(response)
        assert len(chunks) > 0

        assert self.mock_http.called
        call_args = self.mock_http.call_args
        assert call_args is not None

        kwargs = call_args[1]
        data = json.loads(kwargs["data"])
        assert data["stream"] is True
        assert data["messages"] == messages

        assert isinstance(chunks, list)
        assert isinstance(chunks[0], ModelResponseStream)
        assert isinstance(chunks[0].id, str)
        assert chunks[0].model == self.model.split("/", 1)[1]

    @pytest.mark.asyncio
    async def test_async_completion_streaming(self):
        """Test async streaming completion"""
        messages = [{"role": "user", "content": "This is a dummy message"}]
        response = await litellm.acompletion(
            model=self.model, messages=messages, stream=True
        )

        chunks = []
        async for chunk in response:
            chunks.append(chunk)

        assert self.mock_http_async.called
        assert len(chunks) > 0
        assert isinstance(chunks[0], ModelResponseStream)
        assert isinstance(chunks[0].id, str)
        assert chunks[0].model == self.model.split("/", 1)[1]

    @pytest.mark.asyncio
    async def test_async_completion_non_streaming(self):
        """Test async non-streaming completion"""
        messages = [{"role": "user", "content": "This is a dummy message"}]
        response = await litellm.acompletion(
            model=self.model, messages=messages, stream=False
        )

        assert self.mock_http_async.called
        assert isinstance(response, ModelResponse)
        assert (
            response.choices[0].message.content
            == "This is a test response from the mocked HuggingFace API."
        )
        assert response.usage is not None
        assert response.model == self.model.split("/", 2)[2]

    def test_tool_call_no_arguments(self, tool_call_no_arguments):
        mock_tool_response = {
            **MOCK_COMPLETION_RESPONSE,
            "choices": [
                {
                    "finish_reason": "tool_calls",
                    "index": 0,
                    "message": tool_call_no_arguments,
                }
            ],
        }

        with patch.object(
            self.mock_http,
            "side_effect",
            lambda *args, **kwargs: MagicMock(
                status_code=200,
                json=lambda: mock_tool_response,
                raise_for_status=lambda: None,
            ),
        ):
            messages = [{"role": "user", "content": "Get the FAQ"}]
            tools = [
                {
                    "type": "function",
                    "function": {
                        "name": "Get-FAQ",
                        "description": "Get FAQ information",
                        "parameters": {
                            "type": "object",
                            "properties": {},
                            "required": [],
                        },
                    },
                }
            ]

            response = litellm.completion(
                model=self.model, messages=messages, tools=tools, tool_choice="auto"
            )

            assert response.choices[0].message.tool_calls is not None
            assert len(response.choices[0].message.tool_calls) == 1
            assert (
                response.choices[0].message.tool_calls[0].function.name
                == tool_call_no_arguments["tool_calls"][0]["function"]["name"]
            )
            assert (
                response.choices[0].message.tool_calls[0].function.arguments
                == tool_call_no_arguments["tool_calls"][0]["function"]["arguments"]
            )

    @pytest.mark.parametrize(
        "model, expected_url",
        [
            (
                "meta-llama/Llama-3-8B-Instruct",
                "https://router.huggingface.co/v1/chat/completions",
            ),
            (
                "together/meta-llama/Llama-3-8B-Instruct",
                "https://router.huggingface.co/together/v1/chat/completions",
            ),
            (
                "novita/meta-llama/Llama-3-8B-Instruct",
                "https://router.huggingface.co/novita/v3/openai/chat/completions",
            ),
            (
                "http://custom-endpoint.com/v1/chat/completions",
                "http://custom-endpoint.com/v1/chat/completions",
            ),
        ],
    )
    def test_get_complete_url(self, model, expected_url):
        """Test that the complete URL is constructed correctly for different providers"""
        from litellm.llms.huggingface.chat.transformation import HuggingFaceChatConfig

        config = HuggingFaceChatConfig()
        url = config.get_complete_url(
            api_base=None,
            model=model,
            optional_params={},
            stream=False,
            api_key="test_api_key",
            litellm_params={},
        )
        assert url == expected_url

    @pytest.mark.parametrize(
        "api_base, model, expected_url",
        [
            (
                "https://abcd123.us-east-1.aws.endpoints.huggingface.cloud",
                "huggingface/tgi",
                "https://abcd123.us-east-1.aws.endpoints.huggingface.cloud/v1/chat/completions",
            ),
            (
                "https://abcd123.us-east-1.aws.endpoints.huggingface.cloud/",
                "huggingface/tgi",
                "https://abcd123.us-east-1.aws.endpoints.huggingface.cloud/v1/chat/completions",
            ),
            (
                "https://abcd123.us-east-1.aws.endpoints.huggingface.cloud/v1/chat/completions",
                "huggingface/tgi",
                "https://abcd123.us-east-1.aws.endpoints.huggingface.cloud/v1/chat/completions",
            ),
            (
                "https://example.com/custom/path",
                "huggingface/tgi",
                "https://example.com/custom/path/v1/chat/completions",
            ),
            (
                "https://example.com/custom/path/v1/chat/completions",
                "huggingface/tgi",
                "https://example.com/custom/path/v1/chat/completions",
            ),
            (
                "https://example.com/v1",
                "huggingface/tgi",
                "https://example.com/v1/chat/completions",
            ),
        ],
    )
    def test_get_complete_url_inference_endpoints(self, api_base, model, expected_url):
        """Test URL construction when a dedicated Inference Endpoint api_base is supplied"""
        from litellm.llms.huggingface.chat.transformation import HuggingFaceChatConfig

        config = HuggingFaceChatConfig()
        url = config.get_complete_url(
            api_base=api_base,
            model=model,
            optional_params={},
            stream=False,
            api_key="test_api_key",
            litellm_params={},
        )
        assert url == expected_url

    def test_completion_with_api_base(self):
        messages = [{"role": "user", "content": "This is a test message"}]
        api_base = "https://abcd123.us-east-1.aws.endpoints.huggingface.cloud"

        response = litellm.completion(
            model="huggingface/tgi", messages=messages, api_base=api_base, stream=False
        )

        assert isinstance(response, ModelResponse)
        assert (
            response.choices[0].message.content
            == "This is a test response from the mocked HuggingFace API."
        )

        assert self.mock_http.called
        call_args = self.mock_http.call_args
        assert call_args is not None

        called_url = call_args[1]["url"]
        assert called_url == f"{api_base}/v1/chat/completions"

    @pytest.mark.asyncio
    async def test_async_completion_with_api_base(self):
        messages = [{"role": "user", "content": "This is a test message"}]
        api_base = "https://abcd123.us-east-1.aws.endpoints.huggingface.cloud"

        response = await litellm.acompletion(
            model="huggingface/tgi", messages=messages, api_base=api_base, stream=False
        )

        assert isinstance(response, ModelResponse)
        assert (
            response.choices[0].message.content
            == "This is a test response from the mocked HuggingFace API."
        )

        assert self.mock_http_async.called
        call_args = self.mock_http_async.call_args
        assert call_args is not None

        called_url = call_args[1]["url"]
        assert called_url == f"{api_base}/v1/chat/completions"

    def test_completion_streaming_with_api_base(self):
        """Test streaming completion with api_base parameter"""
        messages = [{"role": "user", "content": "This is a test message"}]
        api_base = "https://abcd123.us-east-1.aws.endpoints.huggingface.cloud"

        response = litellm.completion(
            model="huggingface/tgi", messages=messages, api_base=api_base, stream=True
        )

        chunks = list(response)
        assert len(chunks) > 0
        assert isinstance(chunks[0], ModelResponseStream)

        # Check that the correct URL was called
        assert self.mock_http.called
        call_args = self.mock_http.call_args
        assert call_args is not None

        called_url = call_args[1]["url"]
        assert called_url == f"{api_base}/v1/chat/completions"

    def test_build_chat_completion_url_function(self):
        """Test the _build_chat_completion_url helper function"""
        from litellm.llms.huggingface.chat.transformation import (
            _build_chat_completion_url,
        )

        test_cases = [
            ("https://example.com", "https://example.com/v1/chat/completions"),
            ("https://example.com/", "https://example.com/v1/chat/completions"),
            ("https://example.com/v1", "https://example.com/v1/chat/completions"),
            ("https://example.com/v1/", "https://example.com/v1/chat/completions"),
            (
                "https://example.com/v1/chat/completions",
                "https://example.com/v1/chat/completions",
            ),
            (
                "https://example.com/custom/path",
                "https://example.com/custom/path/v1/chat/completions",
            ),
            (
                "https://example.com/custom/path/",
                "https://example.com/custom/path/v1/chat/completions",
            ),
        ]

        for input_url, expected_url in test_cases:
            result = _build_chat_completion_url(input_url)
            assert (
                result == expected_url
            ), f"Failed for input: {input_url}, expected: {expected_url}, got: {result}"

    def test_validate_environment(self):
        """Test that the environment is validated correctly"""
        from litellm.llms.huggingface.chat.transformation import HuggingFaceChatConfig

        config = HuggingFaceChatConfig()

        headers = config.validate_environment(
            headers={},
            model="huggingface/fireworks-ai/meta-llama/Meta-Llama-3-8B-Instruct",
            messages=[{"role": "user", "content": "Hello"}],
            optional_params={},
            api_key="test_api_key",
            litellm_params={},
        )

        assert headers["Authorization"] == "Bearer test_api_key"
        assert headers["content-type"] == "application/json"

    @pytest.mark.parametrize(
        "model, expected_model",
        [
            (
                "together/meta-llama/Llama-3-8B-Instruct",
                "meta-llama/Meta-Llama-3-8B-Instruct-Turbo",
            ),
            (
                "meta-llama/Meta-Llama-3-8B-Instruct",
                "meta-llama/Meta-Llama-3-8B-Instruct",
            ),
        ],
    )
    def test_transform_request(self, model, expected_model):
        """Test that the model is mapped to the provider-specific id in the request body"""
        from litellm.llms.huggingface.chat.transformation import HuggingFaceChatConfig

        config = HuggingFaceChatConfig()
        messages = [{"role": "user", "content": "Hello"}]

        transformed_request = config.transform_request(
            model=model,
            messages=messages,
            optional_params={},
            litellm_params={},
            headers={},
        )

        assert transformed_request["model"] == expected_model
        assert transformed_request["messages"] == messages

    @pytest.mark.asyncio
    async def test_completion_cost(self):
        # No-op override: skips the base-class completion-cost test for this
        # mocked provider.
        pass