Added LiteLLM to the stack
Development/litellm/tests/llm_translation/test_triton.py (new file, 253 lines)
@@ -0,0 +1,253 @@
import json
import os
import sys
import traceback

from dotenv import load_dotenv

load_dotenv()
from unittest.mock import MagicMock, patch

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path

import pytest

import litellm
from litellm.llms.triton.embedding.transformation import TritonEmbeddingConfig


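# Tests for TritonEmbeddingConfig.split_embedding_by_shape: the happy path
# reshapes a flat list of values into rows, and a shape that cannot describe
# the data should raise a ValueError.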
def test_split_embedding_by_shape_passes():
    try:
        data = [
            {
                "shape": [2, 3],
                "data": [1, 2, 3, 4, 5, 6],
            }
        ]
        split_output_data = TritonEmbeddingConfig.split_embedding_by_shape(
            data[0]["data"], data[0]["shape"]
        )
        assert split_output_data == [[1, 2, 3], [4, 5, 6]]
    except Exception as e:
        pytest.fail(f"An exception occurred: {e}")


def test_split_embedding_by_shape_fails_with_shape_value_error():
    data = [
        {
            "shape": [2],
            "data": [1, 2, 3, 4, 5, 6],
        }
    ]
    with pytest.raises(ValueError):
        TritonEmbeddingConfig.split_embedding_by_shape(
            data[0]["data"], data[0]["shape"]
        )


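# Completion via Triton's /generate endpoint, with the HTTP layer mocked out.
# The streaming variant replays SSE-style "data: ..." lines and expects the
# request to be routed to /generate_stream; the non-streaming variant returns
# a single JSON body.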
@pytest.mark.parametrize("stream", [True, False])
def test_completion_triton_generate_api(stream):
    try:
        mock_response = MagicMock()

        if stream:

            def mock_iter_lines():
                mock_output = "".join(
                    [
                        'data: {"model_name":"ensemble","model_version":"1","sequence_end":false,"sequence_id":0,"sequence_start":false,"text_output":"'
                        + t
                        + '"}\n\n'
                        for t in ["I", " am", " an", " AI", " assistant"]
                    ]
                )
                for out in mock_output.split("\n"):
                    yield out

            mock_response.iter_lines = mock_iter_lines
        else:

            def return_val():
                return {
                    "text_output": "I am an AI assistant",
                }

            mock_response.json = return_val
        mock_response.status_code = 200

        with patch(
            "litellm.llms.custom_httpx.http_handler.HTTPHandler.post",
            return_value=mock_response,
        ) as mock_post:
            response = litellm.completion(
                model="triton/llama-3-8b-instruct",
                messages=[{"role": "user", "content": "who are u?"}],
                max_tokens=10,
                timeout=5,
                api_base="http://localhost:8000/generate",
                stream=stream,
            )

            # Verify the call was made
            mock_post.assert_called_once()

            # Get the arguments passed to the post request
            print("call args", mock_post.call_args)
            call_kwargs = mock_post.call_args.kwargs  # Access kwargs directly

            # Verify URL: streaming requests should go to /generate_stream
            if stream:
                assert call_kwargs["url"] == "http://localhost:8000/generate_stream"
            else:
                assert call_kwargs["url"] == "http://localhost:8000/generate"

            # Parse the request data from the JSON string
            request_data = json.loads(call_kwargs["data"])

            # Verify request data
            assert request_data["text_input"] == "who are u?"
            assert request_data["parameters"]["max_tokens"] == 10

            # Verify response; the final streaming chunk carries no content,
            # hence the trailing None
            if stream:
                tokens = ["I", " am", " an", " AI", " assistant", None]
                idx = 0
                for chunk in response:
                    assert chunk.choices[0].delta.content == tokens[idx]
                    idx += 1
                assert idx == len(tokens)
            else:
                assert response.choices[0].message.content == "I am an AI assistant"

    except Exception as e:
        print("exception", e)
        traceback.print_exc()
        pytest.fail(f"Error occurred: {e}")


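# Completion via Triton's /infer endpoint. The mocked response mimics a
# Triton inference payload ("outputs" entries with name/datatype/shape/data);
# the assertions check that LiteLLM builds the corresponding "inputs" request
# and maps the text_output data back to the message content.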
def test_completion_triton_infer_api():
    litellm.set_verbose = True
    try:
        mock_response = MagicMock()

        def return_val():
            return {
                "model_name": "basketgpt",
                "model_version": "2",
                "outputs": [
                    {
                        "name": "text_output",
                        "datatype": "BYTES",
                        "shape": [1],
                        "data": [
                            "0004900005024 0004900006774 0004900005024 0004900005027 0004900005026 0004900005025 0004900005027 0004900005024 0004900006774 0004900005027"
                        ],
                    },
                    {
                        "name": "debug_probs",
                        "datatype": "FP32",
                        "shape": [0],
                        "data": [],
                    },
                    {
                        "name": "debug_tokens",
                        "datatype": "BYTES",
                        "shape": [0],
                        "data": [],
                    },
                ],
            }

        mock_response.json = return_val
        mock_response.status_code = 200

        with patch(
            "litellm.llms.custom_httpx.http_handler.HTTPHandler.post",
            return_value=mock_response,
        ) as mock_post:
            response = litellm.completion(
                model="triton/llama-3-8b-instruct",
                messages=[
                    {
                        "role": "user",
                        "content": "0004900005025 0004900005026 0004900005027",
                    }
                ],
                api_base="http://localhost:8000/infer",
            )

            print("litellm response", response.model_dump_json(indent=4))

            # Verify the call was made
            mock_post.assert_called_once()

            # Get the arguments passed to the post request
            call_kwargs = mock_post.call_args.kwargs

            # Verify URL
            assert call_kwargs["url"] == "http://localhost:8000/infer"

            # Parse the request data from the JSON string
            request_data = json.loads(call_kwargs["data"])

            # Verify request matches expected Triton format
            assert request_data["inputs"][0]["name"] == "text_input"
            assert request_data["inputs"][0]["shape"] == [1]
            assert request_data["inputs"][0]["datatype"] == "BYTES"
            assert request_data["inputs"][0]["data"] == [
                "0004900005025 0004900005026 0004900005027"
            ]

            assert request_data["inputs"][1]["shape"] == [1]
            assert request_data["inputs"][1]["datatype"] == "INT32"
            assert request_data["inputs"][1]["data"] == [20]

            # Verify response format matches expected completion format
            assert (
                response.choices[0].message.content
                == "0004900005024 0004900006774 0004900005024 0004900005027 0004900005026 0004900005025 0004900005027 0004900005024 0004900006774 0004900005027"
            )
            assert response.choices[0].finish_reason == "stop"
            assert response.choices[0].index == 0
            assert response.object == "chat.completion"

    except Exception as e:
        print("exception", e)
        traceback.print_exc()
        pytest.fail(f"Error occurred: {e}")


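# Async embeddings against a stubbed remote endpoint (no local Triton server
# required); the stub is set up to return the fixed vector asserted below.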
@pytest.mark.asyncio
async def test_triton_embeddings():
    try:
        litellm.set_verbose = True
        response = await litellm.aembedding(
            model="triton/my-triton-model",
            api_base="https://exampleopenaiendpoint-production.up.railway.app/triton/embeddings",
            input=["good morning from litellm"],
        )
        print(f"response: {response}")

        # stubbed endpoint is set up to return this
        assert response.data[0]["embedding"] == [0.1, 0.2]
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")


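# return_raw_request should expose the request LiteLLM would send without
# making a network call; bad_words/stop_words must not be injected by default.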
def test_triton_generate_raw_request():
    from litellm.types.utils import CallTypes
    from litellm.utils import return_raw_request

    try:
        kwargs = {
            "model": "triton/llama-3-8b-instruct",
            "messages": [{"role": "user", "content": "who are u?"}],
            "api_base": "http://localhost:8000/generate",
        }
        raw_request = return_raw_request(endpoint=CallTypes.completion, kwargs=kwargs)
        print("raw_request", raw_request)
        assert raw_request is not None
        assert "bad_words" not in json.dumps(raw_request["raw_request_body"])
        assert "stop_words" not in json.dumps(raw_request["raw_request_body"])
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")