Added LiteLLM to the stack
This commit is contained in:
@@ -0,0 +1,78 @@
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
import pytest
|
||||
import base64
|
||||
import httpx
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system path
|
||||
|
||||
import litellm
|
||||
from litellm.llms.custom_httpx.http_handler import HTTPHandler, AsyncHTTPHandler
|
||||
|
||||
titan_embedding_response = {
|
||||
"embedding": [0.1, 0.2, 0.3],
|
||||
"inputTextTokenCount": 10
|
||||
}
|
||||
|
||||
cohere_embedding_response = {
|
||||
"embeddings": [[0.1, 0.2, 0.3]],
|
||||
"inputTextTokenCount": 10
|
||||
}
|
||||
|
||||
img_base_64 = ""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model,input_type,embed_response",
|
||||
[
|
||||
("bedrock/amazon.titan-embed-text-v1", "text", titan_embedding_response), # V1 text model
|
||||
("bedrock/amazon.titan-embed-text-v2:0", "text", titan_embedding_response), # V2 text model
|
||||
("bedrock/amazon.titan-embed-image-v1", "image", titan_embedding_response), # Image model
|
||||
("bedrock/cohere.embed-english-v3", "text", cohere_embedding_response), # Cohere English
|
||||
("bedrock/cohere.embed-multilingual-v3", "text", cohere_embedding_response), # Cohere Multilingual
|
||||
],
|
||||
)
|
||||
def test_bedrock_embedding_models(model, input_type, embed_response):
|
||||
"""Test embedding functionality for all Bedrock models with different input types"""
|
||||
litellm.set_verbose = True
|
||||
client = HTTPHandler()
|
||||
|
||||
with patch.object(client, "post") as mock_post:
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.text = json.dumps(embed_response)
|
||||
mock_response.json = lambda: json.loads(mock_response.text)
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
# Prepare input based on type
|
||||
input_data = img_base_64 if input_type == "image" else "Hello world from litellm"
|
||||
|
||||
try:
|
||||
response = litellm.embedding(
|
||||
model=model,
|
||||
input=input_data,
|
||||
client=client,
|
||||
aws_region_name="us-west-2",
|
||||
aws_bedrock_runtime_endpoint="https://bedrock-runtime.us-west-2.amazonaws.com",
|
||||
)
|
||||
|
||||
# Verify response structure
|
||||
assert isinstance(response, litellm.EmbeddingResponse)
|
||||
print(response.data)
|
||||
assert isinstance(response.data[0]['embedding'], list)
|
||||
assert len(response.data[0]['embedding']) == 3 # Based on mock response
|
||||
|
||||
# Fetch request body
|
||||
request_data = json.loads(mock_post.call_args.kwargs["data"])
|
||||
|
||||
# Verify AWS params are not in request body
|
||||
aws_params = ["aws_region_name", "aws_bedrock_runtime_endpoint"]
|
||||
for param in aws_params:
|
||||
assert param not in request_data, f"AWS param {param} should not be in request body"
|
||||
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
Reference in New Issue
Block a user