Added LiteLLM to the stack
This commit is contained in:
@@ -0,0 +1,22 @@
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
|
||||
# Ensure the project root is on the import path so `litellm` can be imported when
|
||||
# tests are executed from any working directory.
|
||||
sys.path.insert(0, os.path.abspath("../../../../../.."))
|
||||
|
||||
from litellm.llms.bedrock.chat.invoke_transformations.anthropic_claude3_transformation import (
|
||||
AmazonAnthropicClaudeConfig,
|
||||
)
|
||||
|
||||
|
||||
def test_get_supported_params_thinking():
|
||||
config = AmazonAnthropicClaudeConfig()
|
||||
params = config.get_supported_openai_params(
|
||||
model="anthropic.claude-sonnet-4-20250514-v1:0"
|
||||
)
|
||||
assert "thinking" in params
|
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,169 @@
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../../../../..")
|
||||
) # Adds the parent directory to the system path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from litellm.llms.bedrock.chat.invoke_handler import AWSEventStreamDecoder
|
||||
|
||||
|
||||
def test_transform_thinking_blocks_with_redacted_content():
|
||||
thinking_block = {"redactedContent": "This is a redacted content"}
|
||||
decoder = AWSEventStreamDecoder(model="test")
|
||||
transformed_thinking_blocks = decoder.translate_thinking_blocks(thinking_block)
|
||||
assert len(transformed_thinking_blocks) == 1
|
||||
assert transformed_thinking_blocks[0]["type"] == "redacted_thinking"
|
||||
assert transformed_thinking_blocks[0]["data"] == "This is a redacted content"
|
||||
|
||||
|
||||
def test_transform_tool_calls_index():
|
||||
chunks = [
|
||||
{
|
||||
"delta": {"text": "Certainly! I can help you with the"},
|
||||
"contentBlockIndex": 0,
|
||||
},
|
||||
{
|
||||
"delta": {"text": " current weather and time in Tokyo."},
|
||||
"contentBlockIndex": 0,
|
||||
},
|
||||
{"delta": {"text": " To get this information, I'll"}, "contentBlockIndex": 0},
|
||||
{"delta": {"text": " need to use two"}, "contentBlockIndex": 0},
|
||||
{"delta": {"text": " different tools: one"}, "contentBlockIndex": 0},
|
||||
{"delta": {"text": " for the weather and one for"}, "contentBlockIndex": 0},
|
||||
{"delta": {"text": " the time. Let me fetch"}, "contentBlockIndex": 0},
|
||||
{"delta": {"text": " that data for you."}, "contentBlockIndex": 0},
|
||||
{
|
||||
"start": {
|
||||
"toolUse": {
|
||||
"toolUseId": "tooluse_JX1wqyUvRjyTcVSg_6-JwA",
|
||||
"name": "Weather_Tool",
|
||||
}
|
||||
},
|
||||
"contentBlockIndex": 1,
|
||||
},
|
||||
{"delta": {"toolUse": {"input": ""}}, "contentBlockIndex": 1},
|
||||
{"delta": {"toolUse": {"input": '{"locatio'}}, "contentBlockIndex": 1},
|
||||
{"delta": {"toolUse": {"input": 'n": "Toky'}}, "contentBlockIndex": 1},
|
||||
{"delta": {"toolUse": {"input": 'o"}'}}, "contentBlockIndex": 1},
|
||||
{
|
||||
"start": {
|
||||
"toolUse": {
|
||||
"toolUseId": "tooluse_rxDBNjDMQ-mqA-YOp9_3cQ",
|
||||
"name": "Query_Time_Tool",
|
||||
}
|
||||
},
|
||||
"contentBlockIndex": 2,
|
||||
},
|
||||
{"delta": {"toolUse": {"input": ""}}, "contentBlockIndex": 2},
|
||||
{"delta": {"toolUse": {"input": '{"locati'}}, "contentBlockIndex": 2},
|
||||
{"delta": {"toolUse": {"input": 'on"'}}, "contentBlockIndex": 2},
|
||||
{"delta": {"toolUse": {"input": ': "Tokyo"}'}}, "contentBlockIndex": 2},
|
||||
{"stopReason": "tool_use"},
|
||||
]
|
||||
decoder = AWSEventStreamDecoder(model="test")
|
||||
parsed_chunks = []
|
||||
for chunk in chunks:
|
||||
parsed_chunk = decoder._chunk_parser(chunk)
|
||||
parsed_chunks.append(parsed_chunk)
|
||||
tool_call_chunks1 = parsed_chunks[8:12]
|
||||
tool_call_chunks2 = parsed_chunks[13:17]
|
||||
for tool_call_hunk in tool_call_chunks1:
|
||||
tool_call_hunk_dict = tool_call_hunk.model_dump()
|
||||
for tool_call in tool_call_hunk_dict["choices"][0]["delta"]["tool_calls"]:
|
||||
assert tool_call["index"] == 0
|
||||
for tool_call_hunk in tool_call_chunks2:
|
||||
tool_call_hunk_dict = tool_call_hunk.model_dump()
|
||||
for tool_call in tool_call_hunk_dict["choices"][0]["delta"]["tool_calls"]:
|
||||
assert tool_call["index"] == 1
|
||||
|
||||
|
||||
def test_transform_tool_calls_index_with_optional_arg_func():
|
||||
chunks = [
|
||||
{
|
||||
"contentBlockIndex": 0,
|
||||
"delta": {"text": "To"},
|
||||
"p": "abcdefghijklmnopqrstuv",
|
||||
},
|
||||
{
|
||||
"contentBlockIndex": 0,
|
||||
"delta": {"text": " get the current time, I"},
|
||||
"p": "abcdefghijklmnopqrstuvwxyzABCD",
|
||||
},
|
||||
{
|
||||
"contentBlockIndex": 0,
|
||||
"delta": {"text": ' can use the "get_time"'},
|
||||
"p": "abcdefghijkl",
|
||||
},
|
||||
{
|
||||
"contentBlockIndex": 0,
|
||||
"delta": {"text": " function. Since the user"},
|
||||
"p": "abcdefghijkl",
|
||||
},
|
||||
{
|
||||
"contentBlockIndex": 0,
|
||||
"delta": {"text": " didn't specify whether"},
|
||||
"p": "abcdefghijklmnopqrstuvw",
|
||||
},
|
||||
{
|
||||
"contentBlockIndex": 0,
|
||||
"delta": {"text": " they want UTC time or local time,"},
|
||||
"p": "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUV",
|
||||
},
|
||||
{
|
||||
"contentBlockIndex": 0,
|
||||
"delta": {"text": " I'll assume they"},
|
||||
"p": "abcdefghijkl",
|
||||
},
|
||||
{
|
||||
"contentBlockIndex": 0,
|
||||
"delta": {"text": " want the local time. Here's"},
|
||||
"p": "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMN",
|
||||
},
|
||||
{
|
||||
"contentBlockIndex": 0,
|
||||
"delta": {"text": " how I"},
|
||||
"p": "abcdefghijklmnopqrstuvw",
|
||||
},
|
||||
{
|
||||
"contentBlockIndex": 0,
|
||||
"delta": {"text": "'ll make the function call:"},
|
||||
"p": "abcdefghijklmnopqrstuvwxyzAB",
|
||||
},
|
||||
{
|
||||
"contentBlockIndex": 0,
|
||||
"p": "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ",
|
||||
},
|
||||
{
|
||||
"contentBlockIndex": 1,
|
||||
"p": "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNO",
|
||||
"start": {
|
||||
"toolUse": {
|
||||
"name": "get_time",
|
||||
"toolUseId": "tooluse_htgmgeJATsKTl4s_LW77sQ",
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
"contentBlockIndex": 1,
|
||||
"delta": {"toolUse": {"input": ""}},
|
||||
"p": "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUV",
|
||||
},
|
||||
{"contentBlockIndex": 1, "p": "abcdefghijklmnopqrstuvw"},
|
||||
{"p": "abcdefghijklmnopqrstuvwxyzABCDEFGHIJK", "stopReason": "tool_use"},
|
||||
]
|
||||
decoder = AWSEventStreamDecoder(model="test")
|
||||
parsed_chunks = []
|
||||
for chunk in chunks:
|
||||
parsed_chunk = decoder._chunk_parser(chunk)
|
||||
parsed_chunks.append(parsed_chunk)
|
||||
tool_call_chunks = parsed_chunks[11:14]
|
||||
for tool_call_hunk in tool_call_chunks:
|
||||
tool_call_hunk_dict = tool_call_hunk.model_dump()
|
||||
for tool_call in tool_call_hunk_dict["choices"][0]["delta"]["tool_calls"]:
|
||||
assert tool_call["index"] == 0
|
@@ -0,0 +1,32 @@
|
||||
from litellm.llms.bedrock.chat.invoke_transformations.amazon_mistral_transformation import (
|
||||
AmazonMistralConfig,
|
||||
)
|
||||
from litellm.types.utils import ModelResponse
|
||||
|
||||
|
||||
def test_mistral_get_outputText():
|
||||
# Set initial model response with arbitrary finish reason
|
||||
model_response = ModelResponse()
|
||||
model_response.choices[0].finish_reason = "None"
|
||||
|
||||
# Models like pixtral will return a completion with the openai format.
|
||||
mock_json_with_choices = {
|
||||
"choices": [{"message": {"content": "Hello!"}, "finish_reason": "stop"}]
|
||||
}
|
||||
|
||||
outputText = AmazonMistralConfig.get_outputText(
|
||||
completion_response=mock_json_with_choices, model_response=model_response
|
||||
)
|
||||
|
||||
assert outputText == "Hello!"
|
||||
assert model_response.choices[0].finish_reason == "stop"
|
||||
|
||||
# Other models might return a completion behind "outputs"
|
||||
mock_json_with_output = {"outputs": [{"text": "Hi!", "stop_reason": "finish"}]}
|
||||
|
||||
outputText = AmazonMistralConfig.get_outputText(
|
||||
completion_response=mock_json_with_output, model_response=model_response
|
||||
)
|
||||
|
||||
assert outputText == "Hi!"
|
||||
assert model_response.choices[0].finish_reason == "finish"
|
@@ -0,0 +1,153 @@
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from unittest.mock import Mock, patch
|
||||
import pytest
|
||||
|
||||
sys.path.insert(0, os.path.abspath("../../../../..")) # Adds the parent directory to the system path
|
||||
import litellm
|
||||
from litellm.llms.custom_httpx.http_handler import HTTPHandler, AsyncHTTPHandler
|
||||
|
||||
# Mock responses for different embedding models
|
||||
titan_embedding_response = {
|
||||
"embedding": [0.1, 0.2, 0.3],
|
||||
"inputTextTokenCount": 10
|
||||
}
|
||||
|
||||
cohere_embedding_response = {
|
||||
"embeddings": [[0.1, 0.2, 0.3]],
|
||||
"inputTextTokenCount": 10
|
||||
}
|
||||
|
||||
# Test data
|
||||
test_input = "Hello world from litellm"
|
||||
test_image_base64 = "data:image/png,test_image_base64_data"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model,input_type,embed_response",
|
||||
[
|
||||
("bedrock/amazon.titan-embed-text-v1", "text", titan_embedding_response),
|
||||
("bedrock/amazon.titan-embed-text-v2:0", "text", titan_embedding_response),
|
||||
("bedrock/amazon.titan-embed-image-v1", "image", titan_embedding_response),
|
||||
("bedrock/cohere.embed-english-v3", "text", cohere_embedding_response),
|
||||
("bedrock/cohere.embed-multilingual-v3", "text", cohere_embedding_response),
|
||||
],
|
||||
)
|
||||
def test_bedrock_embedding_with_api_key_bearer_token(model, input_type, embed_response):
|
||||
"""Test embedding functionality with bearer token authentication"""
|
||||
litellm.set_verbose = True
|
||||
client = HTTPHandler()
|
||||
test_api_key = "test-bearer-token-12345"
|
||||
|
||||
with patch.object(client, "post") as mock_post:
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.text = json.dumps(embed_response)
|
||||
mock_response.json = lambda: json.loads(mock_response.text)
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
input_data = test_image_base64 if input_type == "image" else test_input
|
||||
|
||||
response = litellm.embedding(
|
||||
model=model,
|
||||
input=input_data,
|
||||
client=client,
|
||||
aws_region_name="us-east-1",
|
||||
aws_bedrock_runtime_endpoint="https://bedrock-runtime.us-east-1.amazonaws.com",
|
||||
api_key=test_api_key
|
||||
)
|
||||
|
||||
assert isinstance(response, litellm.EmbeddingResponse)
|
||||
assert isinstance(response.data[0]['embedding'], list)
|
||||
assert len(response.data[0]['embedding']) == 3 # Based on mock response
|
||||
|
||||
headers = mock_post.call_args.kwargs.get("headers", {})
|
||||
assert "Authorization" in headers
|
||||
assert headers["Authorization"] == f"Bearer {test_api_key}"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model,input_type,embed_response",
|
||||
[
|
||||
("bedrock/amazon.titan-embed-text-v1", "text", titan_embedding_response),
|
||||
],
|
||||
)
|
||||
def test_bedrock_embedding_with_env_variable_bearer_token(model, input_type, embed_response):
|
||||
"""Test embedding functionality with bearer token from environment variable"""
|
||||
litellm.set_verbose = True
|
||||
client = HTTPHandler()
|
||||
test_api_key = "env-bearer-token-12345"
|
||||
|
||||
with patch.dict(os.environ, {"AWS_BEARER_TOKEN_BEDROCK": test_api_key}), \
|
||||
patch.object(client, "post") as mock_post:
|
||||
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.text = json.dumps(embed_response)
|
||||
mock_response.json = lambda: json.loads(mock_response.text)
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
response = litellm.embedding(
|
||||
model=model,
|
||||
input=test_input,
|
||||
client=client,
|
||||
aws_region_name="us-west-2",
|
||||
aws_bedrock_runtime_endpoint="https://bedrock-runtime.us-west-2.amazonaws.com",
|
||||
)
|
||||
|
||||
assert isinstance(response, litellm.EmbeddingResponse)
|
||||
headers = mock_post.call_args.kwargs.get("headers", {})
|
||||
assert "Authorization" in headers
|
||||
assert headers["Authorization"] == f"Bearer {test_api_key}"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_bedrock_embedding_with_bearer_token():
|
||||
"""Test async embedding functionality with bearer token authentication"""
|
||||
litellm.set_verbose = True
|
||||
client = AsyncHTTPHandler()
|
||||
test_api_key = "async-bearer-token-12345"
|
||||
model = "bedrock/amazon.titan-embed-text-v1"
|
||||
|
||||
with patch.object(client, "post") as mock_post:
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.text = json.dumps(titan_embedding_response)
|
||||
mock_response.json = Mock(return_value=titan_embedding_response)
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
response = await litellm.aembedding(
|
||||
model=model,
|
||||
input=test_input,
|
||||
client=client,
|
||||
aws_region_name="us-west-2",
|
||||
aws_bedrock_runtime_endpoint="https://bedrock-runtime.us-west-2.amazonaws.com",
|
||||
api_key=test_api_key
|
||||
)
|
||||
|
||||
assert isinstance(response, litellm.EmbeddingResponse)
|
||||
|
||||
headers = mock_post.call_args.kwargs.get("headers", {})
|
||||
assert "Authorization" in headers
|
||||
assert headers["Authorization"] == f"Bearer {test_api_key}"
|
||||
|
||||
|
||||
def test_bedrock_embedding_with_sigv4():
|
||||
"""Test embedding falls back to SigV4 auth when no bearer token is provided"""
|
||||
litellm.set_verbose = True
|
||||
model = "bedrock/amazon.titan-embed-text-v1"
|
||||
|
||||
with patch("litellm.llms.bedrock.embed.embedding.BedrockEmbedding.embeddings") as mock_bedrock_embed:
|
||||
mock_embedding_response = litellm.EmbeddingResponse()
|
||||
mock_embedding_response.data = [{"embedding": [0.1, 0.2, 0.3]}]
|
||||
mock_bedrock_embed.return_value = mock_embedding_response
|
||||
|
||||
response = litellm.embedding(
|
||||
model=model,
|
||||
input=test_input,
|
||||
aws_region_name="us-west-2",
|
||||
)
|
||||
|
||||
assert isinstance(response, litellm.EmbeddingResponse)
|
||||
mock_bedrock_embed.assert_called_once()
|
@@ -0,0 +1,75 @@
|
||||
import pytest
|
||||
from litellm.llms.bedrock.image.amazon_nova_canvas_transformation import AmazonNovaCanvasConfig
|
||||
from litellm.types.utils import ImageResponse
|
||||
|
||||
def test_transform_request_body_text_to_image():
|
||||
params = {
|
||||
"imageGenerationConfig": {
|
||||
"cfgScale": 7,
|
||||
"seed": 42,
|
||||
"quality": "standard",
|
||||
"width": 512,
|
||||
"height": 512,
|
||||
"numberOfImages": 1,
|
||||
"textToImageParams": {
|
||||
"negativeText": "blurry"
|
||||
}
|
||||
}
|
||||
}
|
||||
req = AmazonNovaCanvasConfig.transform_request_body("cat", params.copy())
|
||||
assert isinstance(req, dict)
|
||||
assert "textToImageParams" in req
|
||||
assert req["textToImageParams"]["text"] == "cat"
|
||||
assert req["imageGenerationConfig"]["width"] == 512
|
||||
|
||||
def test_transform_request_body_color_guided():
|
||||
params = {
|
||||
"taskType": "COLOR_GUIDED_GENERATION",
|
||||
"imageGenerationConfig": {
|
||||
"cfgScale": 7,
|
||||
"seed": 42,
|
||||
"quality": "standard",
|
||||
"width": 512,
|
||||
"height": 512,
|
||||
"numberOfImages": 1,
|
||||
"colorGuidedGenerationParams": {
|
||||
"colors": ["#FFFFFF"],
|
||||
"referenceImage": "img",
|
||||
"negativeText": "blurry"
|
||||
}
|
||||
}
|
||||
}
|
||||
req = AmazonNovaCanvasConfig.transform_request_body("cat", params.copy())
|
||||
assert "colorGuidedGenerationParams" in req
|
||||
assert req["colorGuidedGenerationParams"]["text"] == "cat"
|
||||
assert req["imageGenerationConfig"]["width"] == 512
|
||||
|
||||
def test_transform_request_body_inpainting():
|
||||
params = {
|
||||
"taskType": "INPAINTING",
|
||||
"imageGenerationConfig": {
|
||||
"cfgScale": 7,
|
||||
"seed": 42,
|
||||
"quality": "standard",
|
||||
"width": 512,
|
||||
"height": 512,
|
||||
"numberOfImages": 1,
|
||||
"inpaintingParams": {
|
||||
"maskImage": "mask",
|
||||
"inputImage": "input",
|
||||
"negativeText": "blurry"
|
||||
}
|
||||
}
|
||||
}
|
||||
req = AmazonNovaCanvasConfig.transform_request_body("cat", params.copy())
|
||||
assert "inpaintingParams" in req
|
||||
assert req["inpaintingParams"]["text"] == "cat"
|
||||
assert req["imageGenerationConfig"]["width"] == 512
|
||||
|
||||
def test_transform_response_dict_to_openai_response():
|
||||
response_dict = {"images": ["b64img1", "b64img2"]}
|
||||
model_response = ImageResponse()
|
||||
result = AmazonNovaCanvasConfig.transform_response_dict_to_openai_response(model_response, response_dict)
|
||||
assert hasattr(result, "data")
|
||||
assert len(result.data) == 2
|
||||
assert result.data[0].b64_json == "b64img1"
|
@@ -0,0 +1,20 @@
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../../../../..")
|
||||
) # Adds the parent directory to the system path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from litellm.llms.bedrock.image.amazon_stability3_transformation import (
|
||||
AmazonStability3Config,
|
||||
)
|
||||
|
||||
|
||||
def test_stability_image_core_is_v3_model():
|
||||
model = "stability.stable-image-core-v1:1"
|
||||
assert AmazonStability3Config._is_stability_3_model(model)
|
@@ -0,0 +1,130 @@
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from unittest.mock import Mock, patch
|
||||
import pytest
|
||||
|
||||
sys.path.insert(0, os.path.abspath("../../../../..")) # Adds the parent directory to the system path
|
||||
|
||||
import litellm
|
||||
from litellm.llms.custom_httpx.http_handler import HTTPHandler, AsyncHTTPHandler
|
||||
|
||||
# Mock response for Bedrock image generation
|
||||
mock_image_response = {
|
||||
"images": ["base64_encoded_image_data"],
|
||||
"error": None
|
||||
}
|
||||
|
||||
class TestBedrockImageGeneration:
|
||||
def test_image_generation_with_api_key_bearer_token(self):
|
||||
"""Test image generation with bearer token authentication"""
|
||||
litellm.set_verbose = True
|
||||
test_api_key = "test-bearer-token-12345"
|
||||
model = "bedrock/stability.sd3-large-v1:0"
|
||||
prompt = "A cute baby sea otter"
|
||||
|
||||
with patch("litellm.llms.bedrock.image.image_handler.BedrockImageGeneration.image_generation") as mock_bedrock_image_gen:
|
||||
# Setup mock response
|
||||
mock_image_response_obj = litellm.ImageResponse()
|
||||
mock_image_response_obj.data = [{"url": "https://example.com/image.jpg"}]
|
||||
mock_bedrock_image_gen.return_value = mock_image_response_obj
|
||||
|
||||
response = litellm.image_generation(
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
aws_region_name="us-west-2",
|
||||
api_key=test_api_key
|
||||
)
|
||||
|
||||
assert response is not None
|
||||
assert len(response.data) > 0
|
||||
|
||||
mock_bedrock_image_gen.assert_called_once()
|
||||
for call in mock_bedrock_image_gen.call_args_list:
|
||||
if "headers" in call.kwargs:
|
||||
headers = call.kwargs["headers"]
|
||||
if "Authorization" in headers and headers["Authorization"] == f"Bearer {test_api_key}":
|
||||
break
|
||||
|
||||
def test_image_generation_with_env_variable_bearer_token(self, monkeypatch):
|
||||
"""Test image generation with bearer token from environment variable"""
|
||||
litellm.set_verbose = True
|
||||
test_api_key = "env-bearer-token-12345"
|
||||
model = "bedrock/stability.sd3-large-v1:0"
|
||||
prompt = "A cute baby sea otter"
|
||||
|
||||
# Mock the environment variable
|
||||
with patch.dict(os.environ, {"AWS_BEARER_TOKEN_BEDROCK": test_api_key}), \
|
||||
patch("litellm.llms.bedrock.image.image_handler.BedrockImageGeneration.image_generation") as mock_bedrock_image_gen:
|
||||
|
||||
mock_image_response_obj = litellm.ImageResponse()
|
||||
mock_image_response_obj.data = [{"url": "https://example.com/image.jpg"}]
|
||||
mock_bedrock_image_gen.return_value = mock_image_response_obj
|
||||
|
||||
response = litellm.image_generation(
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
aws_region_name="us-west-2"
|
||||
)
|
||||
|
||||
assert response is not None
|
||||
assert len(response.data) > 0
|
||||
|
||||
mock_bedrock_image_gen.assert_called_once()
|
||||
for call in mock_bedrock_image_gen.call_args_list:
|
||||
if "headers" in call.kwargs:
|
||||
headers = call.kwargs["headers"]
|
||||
if "Authorization" in headers and headers["Authorization"] == f"Bearer {test_api_key}":
|
||||
break
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_image_generation_with_bearer_token(self):
|
||||
"""Test async image generation with bearer token authentication"""
|
||||
litellm.set_verbose = True
|
||||
test_api_key = "async-bearer-token-12345"
|
||||
model = "bedrock/stability.sd3-large-v1:0"
|
||||
prompt = "A cute baby sea otter"
|
||||
|
||||
with patch("litellm.llms.bedrock.image.image_handler.BedrockImageGeneration.async_image_generation") as mock_async_bedrock_image_gen:
|
||||
mock_image_response_obj = litellm.ImageResponse()
|
||||
mock_image_response_obj.data = [{"url": "https://example.com/image.jpg"}]
|
||||
mock_async_bedrock_image_gen.return_value = mock_image_response_obj
|
||||
|
||||
# Call async image generation with api_key parameter
|
||||
response = await litellm.aimage_generation(
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
aws_region_name="us-west-2",
|
||||
api_key=test_api_key
|
||||
)
|
||||
|
||||
assert response is not None
|
||||
assert len(response.data) > 0
|
||||
|
||||
mock_async_bedrock_image_gen.assert_called_once()
|
||||
for call in mock_async_bedrock_image_gen.call_args_list:
|
||||
if "headers" in call.kwargs:
|
||||
headers = call.kwargs["headers"]
|
||||
if "Authorization" in headers and headers["Authorization"] == f"Bearer {test_api_key}":
|
||||
break
|
||||
|
||||
def test_image_generation_with_sigv4(self):
|
||||
"""Test image generation falls back to SigV4 auth when no bearer token is provided"""
|
||||
litellm.set_verbose = True
|
||||
model = "bedrock/stability.sd3-large-v1:0"
|
||||
prompt = "A cute baby sea otter"
|
||||
|
||||
with patch("litellm.llms.bedrock.image.image_handler.BedrockImageGeneration.image_generation") as mock_bedrock_image_gen:
|
||||
mock_image_response_obj = litellm.ImageResponse()
|
||||
mock_image_response_obj.data = [{"url": "https://example.com/image.jpg"}]
|
||||
mock_bedrock_image_gen.return_value = mock_image_response_obj
|
||||
|
||||
response = litellm.image_generation(
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
aws_region_name="us-west-2"
|
||||
)
|
||||
|
||||
assert response is not None
|
||||
assert len(response.data) > 0
|
||||
mock_bedrock_image_gen.assert_called_once()
|
@@ -0,0 +1,272 @@
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import uuid
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../../../..")
|
||||
) # Adds the parent directory to the system path
|
||||
|
||||
from litellm.llms.bedrock.chat.invoke_agent.transformation import (
|
||||
AmazonInvokeAgentConfig,
|
||||
)
|
||||
from litellm.types.llms.bedrock_invoke_agents import (
|
||||
InvokeAgentEvent,
|
||||
InvokeAgentEventHeaders,
|
||||
InvokeAgentUsage,
|
||||
)
|
||||
from litellm.types.utils import Message, ModelResponse, Usage
|
||||
|
||||
|
||||
class TestAmazonInvokeAgentConfig:
|
||||
"""Test suite for AmazonInvokeAgentConfig methods"""
|
||||
|
||||
@pytest.fixture
|
||||
def config(self):
|
||||
"""Create a test instance of AmazonInvokeAgentConfig"""
|
||||
return AmazonInvokeAgentConfig()
|
||||
|
||||
@pytest.fixture
|
||||
def sample_messages(self):
|
||||
"""Sample messages for testing"""
|
||||
return [
|
||||
{"role": "user", "content": "Hello, how can you help me?"},
|
||||
{"role": "assistant", "content": "I can help with various tasks."},
|
||||
{"role": "user", "content": "What is the weather like?"},
|
||||
]
|
||||
|
||||
@pytest.fixture
|
||||
def sample_events(self):
|
||||
"""Sample events for testing event parsing"""
|
||||
return [
|
||||
{
|
||||
"headers": {"event_type": "chunk"},
|
||||
"payload": {
|
||||
"bytes": base64.b64encode("Hello ".encode("utf-8")).decode("utf-8")
|
||||
},
|
||||
},
|
||||
{
|
||||
"headers": {"event_type": "chunk"},
|
||||
"payload": {
|
||||
"bytes": base64.b64encode("world!".encode("utf-8")).decode("utf-8")
|
||||
},
|
||||
},
|
||||
{
|
||||
"headers": {"event_type": "trace"},
|
||||
"payload": {
|
||||
"trace": {
|
||||
"preProcessingTrace": {
|
||||
"modelInvocationOutput": {
|
||||
"metadata": {
|
||||
"usage": {"inputTokens": 10, "outputTokens": 20}
|
||||
}
|
||||
}
|
||||
},
|
||||
"orchestrationTrace": {
|
||||
"modelInvocationInput": {
|
||||
"foundationModel": "anthropic.claude-v2"
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
def test_get_agent_id_and_alias_id_valid(self, config):
|
||||
"""Test parsing valid agent model string"""
|
||||
model = "agent/L1RT58GYRW/MFPSBCXYTW"
|
||||
agent_id, agent_alias_id = config._get_agent_id_and_alias_id(model)
|
||||
|
||||
assert agent_id == "L1RT58GYRW"
|
||||
assert agent_alias_id == "MFPSBCXYTW"
|
||||
|
||||
def test_get_agent_id_and_alias_id_invalid_format(self, config):
|
||||
"""Test parsing invalid agent model string"""
|
||||
invalid_models = [
|
||||
"invalid/L1RT58GYRW/MFPSBCXYTW", # Wrong prefix
|
||||
"agent/L1RT58GYRW", # Missing alias
|
||||
"agent/L1RT58GYRW/MFPSBCXYTW/extra", # Too many parts
|
||||
"L1RT58GYRW/MFPSBCXYTW", # Missing prefix
|
||||
]
|
||||
|
||||
for invalid_model in invalid_models:
|
||||
with pytest.raises(ValueError, match="Invalid model format"):
|
||||
config._get_agent_id_and_alias_id(invalid_model)
|
||||
|
||||
@patch(
|
||||
"litellm.llms.bedrock.chat.invoke_agent.transformation.convert_content_list_to_str"
|
||||
)
|
||||
def test_transform_request(self, mock_convert, config, sample_messages):
|
||||
"""Test transform_request method"""
|
||||
mock_convert.return_value = "What is the weather like?"
|
||||
|
||||
model = "agent/TEST123/ALIAS456"
|
||||
optional_params = {}
|
||||
litellm_params = {}
|
||||
headers = {}
|
||||
|
||||
result = config.transform_request(
|
||||
model, sample_messages, optional_params, litellm_params, headers
|
||||
)
|
||||
|
||||
expected = {
|
||||
"inputText": "What is the weather like?",
|
||||
"enableTrace": True,
|
||||
}
|
||||
assert result == expected
|
||||
mock_convert.assert_called_once_with(sample_messages[-1])
|
||||
|
||||
def test_extract_response_content(self, config, sample_events):
|
||||
"""Test _extract_response_content method"""
|
||||
result = config._extract_response_content(sample_events)
|
||||
assert result == "Hello world!"
|
||||
|
||||
def test_extract_response_content_empty_events(self, config):
|
||||
"""Test _extract_response_content with empty events"""
|
||||
result = config._extract_response_content([])
|
||||
assert result == ""
|
||||
|
||||
def test_extract_response_content_no_chunk_events(self, config):
|
||||
"""Test _extract_response_content with no chunk events"""
|
||||
events = [{"headers": {"event_type": "trace"}, "payload": {"some": "data"}}]
|
||||
result = config._extract_response_content(events)
|
||||
assert result == ""
|
||||
|
||||
def test_is_trace_event(self, config):
|
||||
"""Test _is_trace_event method"""
|
||||
trace_event = {"headers": {"event_type": "trace"}, "payload": {"some": "data"}}
|
||||
chunk_event = {"headers": {"event_type": "chunk"}, "payload": {"bytes": "data"}}
|
||||
invalid_event = {"headers": {"event_type": "trace"}, "payload": None}
|
||||
|
||||
assert config._is_trace_event(trace_event) is True
|
||||
assert config._is_trace_event(chunk_event) is False
|
||||
assert config._is_trace_event(invalid_event) is False
|
||||
|
||||
def test_get_trace_data(self, config):
|
||||
"""Test _get_trace_data method"""
|
||||
event = {"payload": {"trace": {"preProcessingTrace": {"some": "data"}}}}
|
||||
result = config._get_trace_data(event)
|
||||
assert result == {"preProcessingTrace": {"some": "data"}}
|
||||
|
||||
def test_get_trace_data_no_payload(self, config):
|
||||
"""Test _get_trace_data with no payload"""
|
||||
event = {"payload": None}
|
||||
result = config._get_trace_data(event)
|
||||
assert result is None
|
||||
|
||||
def test_extract_usage_info(self, config, sample_events):
|
||||
"""Test _extract_usage_info method"""
|
||||
result = config._extract_usage_info(sample_events)
|
||||
|
||||
assert result["inputTokens"] == 10
|
||||
assert result["outputTokens"] == 20
|
||||
assert result["model"] == "anthropic.claude-v2"
|
||||
|
||||
def test_extract_usage_info_empty_events(self, config):
|
||||
"""Test _extract_usage_info with empty events"""
|
||||
result = config._extract_usage_info([])
|
||||
|
||||
assert result["inputTokens"] == 0
|
||||
assert result["outputTokens"] == 0
|
||||
assert result["model"] is None
|
||||
|
||||
def test_extract_and_update_preprocessing_usage(self, config):
|
||||
"""Test _extract_and_update_preprocessing_usage method"""
|
||||
trace_data = {
|
||||
"preProcessingTrace": {
|
||||
"modelInvocationOutput": {
|
||||
"metadata": {"usage": {"inputTokens": 15, "outputTokens": 25}}
|
||||
}
|
||||
}
|
||||
}
|
||||
usage_info = {"inputTokens": 5, "outputTokens": 10, "model": None}
|
||||
|
||||
config._extract_and_update_preprocessing_usage(trace_data, usage_info)
|
||||
|
||||
assert usage_info["inputTokens"] == 20 # 5 + 15
|
||||
assert usage_info["outputTokens"] == 35 # 10 + 25
|
||||
|
||||
def test_extract_and_update_preprocessing_usage_no_data(self, config):
|
||||
"""Test _extract_and_update_preprocessing_usage with missing data"""
|
||||
trace_data = {}
|
||||
usage_info = {"inputTokens": 5, "outputTokens": 10, "model": None}
|
||||
|
||||
config._extract_and_update_preprocessing_usage(trace_data, usage_info)
|
||||
|
||||
# Should remain unchanged
|
||||
assert usage_info["inputTokens"] == 5
|
||||
assert usage_info["outputTokens"] == 10
|
||||
|
||||
def test_extract_orchestration_model(self, config):
|
||||
"""Test _extract_orchestration_model method"""
|
||||
trace_data = {
|
||||
"orchestrationTrace": {
|
||||
"modelInvocationInput": {"foundationModel": "anthropic.claude-v2"}
|
||||
}
|
||||
}
|
||||
result = config._extract_orchestration_model(trace_data)
|
||||
assert result == "anthropic.claude-v2"
|
||||
|
||||
def test_extract_orchestration_model_no_data(self, config):
|
||||
"""Test _extract_orchestration_model with missing data"""
|
||||
trace_data = {}
|
||||
result = config._extract_orchestration_model(trace_data)
|
||||
assert result is None
|
||||
|
||||
def test_build_model_response(self, config):
|
||||
"""Test _build_model_response method"""
|
||||
content = "Hello, world!"
|
||||
model = "agent/TEST123/ALIAS456"
|
||||
usage_info = {
|
||||
"inputTokens": 10,
|
||||
"outputTokens": 20,
|
||||
"model": "anthropic.claude-v2",
|
||||
}
|
||||
model_response = ModelResponse()
|
||||
|
||||
result = config._build_model_response(
|
||||
content, model, usage_info, model_response
|
||||
)
|
||||
|
||||
assert len(result.choices) == 1
|
||||
assert result.choices[0].message.content == content
|
||||
assert result.choices[0].message.role == "assistant"
|
||||
assert result.choices[0].finish_reason == "stop"
|
||||
assert result.model == "anthropic.claude-v2"
|
||||
assert hasattr(result, "usage")
|
||||
assert result.usage.prompt_tokens == 10
|
||||
assert result.usage.completion_tokens == 20
|
||||
assert result.usage.total_tokens == 30
|
||||
|
||||
@patch(
|
||||
"litellm.llms.bedrock.chat.invoke_agent.transformation.convert_content_list_to_str"
|
||||
)
|
||||
@patch.object(AmazonInvokeAgentConfig, "get_runtime_endpoint")
|
||||
@patch.object(AmazonInvokeAgentConfig, "_get_aws_region_name")
|
||||
def test_get_complete_url(self, mock_region, mock_endpoint, mock_convert, config):
|
||||
"""Test get_complete_url method"""
|
||||
mock_endpoint.return_value = (
|
||||
"https://bedrock-runtime.us-east-1.amazonaws.com",
|
||||
None,
|
||||
)
|
||||
mock_region.return_value = "us-east-1"
|
||||
|
||||
api_base = None
|
||||
api_key = None
|
||||
model = "agent/L1RT58GYRW/MFPSBCXYTW"
|
||||
optional_params = {}
|
||||
litellm_params = {}
|
||||
|
||||
result = config.get_complete_url(
|
||||
api_base, api_key, model, optional_params, litellm_params
|
||||
)
|
||||
|
||||
assert (
|
||||
"https://bedrock-runtime.us-east-1.amazonaws.com/agents/L1RT58GYRW/agentAliases/MFPSBCXYTW/sessions"
|
||||
in result
|
||||
)
|
@@ -0,0 +1,81 @@
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from datetime import datetime
|
||||
|
||||
import pytest
|
||||
|
||||
# Ensure the project root is on the import path so `litellm` can be imported when
|
||||
# tests are executed from any working directory.
|
||||
sys.path.insert(0, os.path.abspath("../../../../../.."))
|
||||
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.llms.bedrock.messages.invoke_transformations.anthropic_claude3_transformation import (
|
||||
AmazonAnthropicClaudeMessagesConfig,
|
||||
AmazonAnthropicClaudeMessagesStreamDecoder,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bedrock_sse_wrapper_encodes_dict_chunks():
|
||||
"""Verify that `bedrock_sse_wrapper` converts dictionary chunks to properly formatted Server-Sent Events and forwards non-dict chunks unchanged."""
|
||||
|
||||
cfg = AmazonAnthropicClaudeMessagesConfig()
|
||||
|
||||
async def _dummy_stream(): # type: ignore[return-type]
|
||||
yield {"type": "message_delta", "text": "hello"}
|
||||
yield b"raw-bytes"
|
||||
|
||||
# Collect all chunks returned by the wrapper
|
||||
collected: list[bytes] = []
|
||||
async for chunk in cfg.bedrock_sse_wrapper(
|
||||
_dummy_stream(),
|
||||
litellm_logging_obj=LiteLLMLoggingObj(
|
||||
model="bedrock/invoke/anthropic.claude-3-sonnet-20240229-v1:0",
|
||||
messages=[{"role": "user", "content": "Hello, can you tell me a short joke?"}],
|
||||
stream=True,
|
||||
call_type="chat",
|
||||
start_time=datetime.now(),
|
||||
litellm_call_id="test_bedrock_sse_wrapper_encodes_dict_chunks",
|
||||
function_id="test_bedrock_sse_wrapper_encodes_dict_chunks",
|
||||
),
|
||||
request_body={},
|
||||
):
|
||||
collected.append(chunk)
|
||||
|
||||
assert collected, "No chunks returned from wrapper"
|
||||
|
||||
# First chunk should be SSE encoded
|
||||
first_chunk = collected[0]
|
||||
assert first_chunk.startswith(b"event: message_delta\n"), first_chunk
|
||||
assert first_chunk.endswith(b"\n\n"), first_chunk
|
||||
# Ensure the JSON payload is present in the SSE data line
|
||||
assert b'"hello"' in first_chunk # payload contains the text
|
||||
|
||||
# Second chunk should be forwarded unchanged
|
||||
assert collected[1] == b"raw-bytes"
|
||||
|
||||
|
||||
def test_chunk_parser_usage_transformation():
|
||||
"""Ensure Bedrock invocation metrics are transformed to Anthropic usage keys."""
|
||||
|
||||
decoder = AmazonAnthropicClaudeMessagesStreamDecoder(
|
||||
model="bedrock/invoke/anthropic.claude-3-sonnet-20240229-v1:0"
|
||||
)
|
||||
|
||||
chunk = {
|
||||
"type": "message_delta",
|
||||
"amazon-bedrock-invocationMetrics": {
|
||||
"inputTokenCount": 10,
|
||||
"outputTokenCount": 5,
|
||||
},
|
||||
}
|
||||
|
||||
parsed = decoder._chunk_parser(chunk.copy()) # use copy to avoid side-effects
|
||||
|
||||
# The invocation metrics key should be removed and replaced by `usage`
|
||||
assert "amazon-bedrock-invocationMetrics" not in parsed
|
||||
assert "usage" in parsed
|
||||
assert parsed["usage"]["input_tokens"] == 10
|
||||
assert parsed["usage"]["output_tokens"] == 5
|
@@ -0,0 +1,14 @@
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../../..")
|
||||
) # Adds the parent directory to the system path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from litellm import rerank
|
||||
from litellm.llms.custom_httpx.http_handler import HTTPHandler
|
@@ -0,0 +1,166 @@
|
||||
"""
|
||||
Test anthropic_beta header support for AWS Bedrock.
|
||||
|
||||
Tests that anthropic-beta headers are correctly processed and passed to AWS Bedrock
|
||||
for enabling beta features like 1M context window, computer use tools, etc.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
import json
|
||||
|
||||
from litellm.llms.bedrock.common_utils import get_anthropic_beta_from_headers
|
||||
from litellm.llms.bedrock.chat.converse_transformation import AmazonConverseConfig
|
||||
from litellm.llms.bedrock.chat.invoke_transformations.anthropic_claude3_transformation import AmazonAnthropicClaudeConfig
|
||||
from litellm.llms.bedrock.messages.invoke_transformations.anthropic_claude3_transformation import AmazonAnthropicClaudeMessagesConfig
|
||||
|
||||
|
||||
class TestAnthropicBetaHeaderSupport:
|
||||
"""Test anthropic_beta header functionality across Bedrock APIs."""
|
||||
|
||||
def test_get_anthropic_beta_from_headers_empty(self):
|
||||
"""Test header extraction with no headers."""
|
||||
headers = {}
|
||||
result = get_anthropic_beta_from_headers(headers)
|
||||
assert result == []
|
||||
|
||||
def test_get_anthropic_beta_from_headers_single(self):
|
||||
"""Test header extraction with single beta header."""
|
||||
headers = {"anthropic-beta": "context-1m-2025-08-07"}
|
||||
result = get_anthropic_beta_from_headers(headers)
|
||||
assert result == ["context-1m-2025-08-07"]
|
||||
|
||||
def test_get_anthropic_beta_from_headers_multiple(self):
|
||||
"""Test header extraction with multiple comma-separated beta headers."""
|
||||
headers = {"anthropic-beta": "context-1m-2025-08-07,computer-use-2024-10-22"}
|
||||
result = get_anthropic_beta_from_headers(headers)
|
||||
assert result == ["context-1m-2025-08-07", "computer-use-2024-10-22"]
|
||||
|
||||
def test_get_anthropic_beta_from_headers_whitespace(self):
|
||||
"""Test header extraction handles whitespace correctly."""
|
||||
headers = {"anthropic-beta": " context-1m-2025-08-07 , computer-use-2024-10-22 "}
|
||||
result = get_anthropic_beta_from_headers(headers)
|
||||
assert result == ["context-1m-2025-08-07", "computer-use-2024-10-22"]
|
||||
|
||||
def test_invoke_transformation_anthropic_beta(self):
|
||||
"""Test that Invoke API transformation includes anthropic_beta in request."""
|
||||
config = AmazonAnthropicClaudeConfig()
|
||||
headers = {"anthropic-beta": "context-1m-2025-08-07,computer-use-2024-10-22"}
|
||||
|
||||
result = config.transform_request(
|
||||
model="anthropic.claude-3-5-sonnet-20241022-v2:0",
|
||||
messages=[{"role": "user", "content": "Test"}],
|
||||
optional_params={},
|
||||
litellm_params={},
|
||||
headers=headers
|
||||
)
|
||||
|
||||
assert "anthropic_beta" in result
|
||||
assert result["anthropic_beta"] == ["context-1m-2025-08-07", "computer-use-2024-10-22"]
|
||||
|
||||
def test_converse_transformation_anthropic_beta(self):
|
||||
"""Test that Converse API transformation includes anthropic_beta in additionalModelRequestFields."""
|
||||
config = AmazonConverseConfig()
|
||||
headers = {"anthropic-beta": "context-1m-2025-08-07,interleaved-thinking-2025-05-14"}
|
||||
|
||||
result = config._transform_request_helper(
|
||||
model="anthropic.claude-3-5-sonnet-20241022-v2:0",
|
||||
system_content_blocks=[],
|
||||
optional_params={},
|
||||
messages=[{"role": "user", "content": "Test"}],
|
||||
headers=headers
|
||||
)
|
||||
|
||||
assert "additionalModelRequestFields" in result
|
||||
additional_fields = result["additionalModelRequestFields"]
|
||||
assert "anthropic_beta" in additional_fields
|
||||
assert additional_fields["anthropic_beta"] == ["context-1m-2025-08-07", "interleaved-thinking-2025-05-14"]
|
||||
|
||||
def test_messages_transformation_anthropic_beta(self):
|
||||
"""Test that Messages API transformation includes anthropic_beta in request."""
|
||||
config = AmazonAnthropicClaudeMessagesConfig()
|
||||
headers = {"anthropic-beta": "output-128k-2025-02-19"}
|
||||
|
||||
result = config.transform_anthropic_messages_request(
|
||||
model="anthropic.claude-3-5-sonnet-20241022-v2:0",
|
||||
messages=[{"role": "user", "content": "Test"}],
|
||||
anthropic_messages_optional_request_params={"max_tokens": 100},
|
||||
litellm_params={},
|
||||
headers=headers
|
||||
)
|
||||
|
||||
assert "anthropic_beta" in result
|
||||
assert result["anthropic_beta"] == ["output-128k-2025-02-19"]
|
||||
|
||||
def test_converse_computer_use_compatibility(self):
|
||||
"""Test that user anthropic_beta headers work with computer use tools."""
|
||||
config = AmazonConverseConfig()
|
||||
headers = {"anthropic-beta": "context-1m-2025-08-07"}
|
||||
|
||||
# Computer use tools should automatically add computer-use-2024-10-22
|
||||
tools = [
|
||||
{
|
||||
"type": "computer_20241022",
|
||||
"name": "computer",
|
||||
"display_width_px": 1024,
|
||||
"display_height_px": 768
|
||||
}
|
||||
]
|
||||
|
||||
result = config._transform_request_helper(
|
||||
model="anthropic.claude-3-5-sonnet-20241022-v2:0",
|
||||
system_content_blocks=[],
|
||||
optional_params={"tools": tools},
|
||||
messages=[{"role": "user", "content": "Test"}],
|
||||
headers=headers
|
||||
)
|
||||
|
||||
additional_fields = result["additionalModelRequestFields"]
|
||||
betas = additional_fields["anthropic_beta"]
|
||||
|
||||
# Should contain both user-provided and auto-added beta headers
|
||||
assert "context-1m-2025-08-07" in betas
|
||||
assert "computer-use-2024-10-22" in betas
|
||||
assert len(betas) == 2 # No duplicates
|
||||
|
||||
def test_no_anthropic_beta_headers(self):
|
||||
"""Test that transformations work correctly when no anthropic_beta headers are provided."""
|
||||
config = AmazonConverseConfig()
|
||||
headers = {}
|
||||
|
||||
result = config._transform_request_helper(
|
||||
model="anthropic.claude-3-5-sonnet-20241022-v2:0",
|
||||
system_content_blocks=[],
|
||||
optional_params={},
|
||||
messages=[{"role": "user", "content": "Test"}],
|
||||
headers=headers
|
||||
)
|
||||
|
||||
additional_fields = result.get("additionalModelRequestFields", {})
|
||||
assert "anthropic_beta" not in additional_fields
|
||||
|
||||
def test_anthropic_beta_all_supported_features(self):
|
||||
"""Test that all documented beta features are properly handled."""
|
||||
supported_features = [
|
||||
"context-1m-2025-08-07",
|
||||
"computer-use-2025-01-24",
|
||||
"computer-use-2024-10-22",
|
||||
"token-efficient-tools-2025-02-19",
|
||||
"interleaved-thinking-2025-05-14",
|
||||
"output-128k-2025-02-19",
|
||||
"dev-full-thinking-2025-05-14"
|
||||
]
|
||||
|
||||
config = AmazonAnthropicClaudeConfig()
|
||||
headers = {"anthropic-beta": ",".join(supported_features)}
|
||||
|
||||
result = config.transform_request(
|
||||
model="anthropic.claude-3-5-sonnet-20241022-v2:0",
|
||||
messages=[{"role": "user", "content": "Test"}],
|
||||
optional_params={},
|
||||
litellm_params={},
|
||||
headers=headers
|
||||
)
|
||||
|
||||
assert "anthropic_beta" in result
|
||||
assert result["anthropic_beta"] == supported_features
|
@@ -0,0 +1,481 @@
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../../..")
|
||||
) # Adds the parent directory to the system path
|
||||
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from botocore.credentials import Credentials
|
||||
from botocore.awsrequest import AWSRequest, AWSPreparedRequest
|
||||
import litellm
|
||||
from litellm.llms.bedrock.base_aws_llm import (
|
||||
AwsAuthError,
|
||||
BaseAWSLLM,
|
||||
Boto3CredentialsInfo,
|
||||
)
|
||||
from litellm.caching.caching import DualCache
|
||||
|
||||
# Global variable for the base_aws_llm.py file path
|
||||
|
||||
BASE_AWS_LLM_PATH = os.path.join(
|
||||
os.path.dirname(__file__), "../../../../litellm/llms/bedrock/base_aws_llm.py"
|
||||
)
|
||||
|
||||
|
||||
def test_boto3_init_tracer_wrapping():
|
||||
"""
|
||||
Test that all boto3 initializations are wrapped in tracer.trace or @tracer.wrap
|
||||
|
||||
Ensures observability of boto3 calls in litellm.
|
||||
"""
|
||||
# Get the source code of base_aws_llm.py
|
||||
with open(BASE_AWS_LLM_PATH, "r") as f:
|
||||
content = f.read()
|
||||
|
||||
# List all boto3 initialization patterns we want to check
|
||||
boto3_init_patterns = ["boto3.client", "boto3.Session"]
|
||||
|
||||
lines = content.split("\n")
|
||||
# Check each boto3 initialization is wrapped in tracer.trace
|
||||
for line_number, line in enumerate(lines, 1):
|
||||
for pattern in boto3_init_patterns:
|
||||
if pattern in line:
|
||||
# Look back up to 5 lines for decorator or trace block
|
||||
start_line = max(0, line_number - 5)
|
||||
context_lines = lines[start_line:line_number]
|
||||
|
||||
has_trace = (
|
||||
"tracer.trace" in line
|
||||
or any("tracer.trace" in prev_line for prev_line in context_lines)
|
||||
or any("@tracer.wrap" in prev_line for prev_line in context_lines)
|
||||
)
|
||||
|
||||
if not has_trace:
|
||||
print(f"\nContext for line {line_number}:")
|
||||
for i, ctx_line in enumerate(context_lines, start=start_line + 1):
|
||||
print(f"{i}: {ctx_line}")
|
||||
|
||||
assert (
|
||||
has_trace
|
||||
), f"boto3 initialization '{pattern}' on line {line_number} is not wrapped with tracer.trace or @tracer.wrap"
|
||||
|
||||
|
||||
def test_auth_functions_tracer_wrapping():
|
||||
"""
|
||||
Test that all _auth functions in base_aws_llm.py are wrapped with @tracer.wrap
|
||||
|
||||
Ensures observability of AWS authentication calls in litellm.
|
||||
"""
|
||||
# Get the source code of base_aws_llm.py
|
||||
with open(BASE_AWS_LLM_PATH, "r") as f:
|
||||
content = f.read()
|
||||
|
||||
lines = content.split("\n")
|
||||
# Check each line for _auth function definitions
|
||||
for line_number, line in enumerate(lines, 1):
|
||||
if line.strip().startswith("def _auth_"):
|
||||
# Look back up to 2 lines for the @tracer.wrap decorator
|
||||
start_line = max(0, line_number - 2)
|
||||
context_lines = lines[start_line:line_number]
|
||||
|
||||
has_tracer_wrap = any(
|
||||
"@tracer.wrap" in prev_line for prev_line in context_lines
|
||||
)
|
||||
|
||||
if not has_tracer_wrap:
|
||||
print(f"\nContext for line {line_number}:")
|
||||
for i, ctx_line in enumerate(context_lines, start=start_line + 1):
|
||||
print(f"{i}: {ctx_line}")
|
||||
|
||||
assert (
|
||||
has_tracer_wrap
|
||||
), f"Auth function on line {line_number} is not wrapped with @tracer.wrap: {line.strip()}"
|
||||
|
||||
|
||||
def test_get_aws_region_name_boto3_fallback():
|
||||
"""
|
||||
Test the boto3 session fallback logic in _get_aws_region_name method.
|
||||
|
||||
This tests the specific code block that tries to get the region from boto3.Session()
|
||||
when aws_region_name is None and not found in environment variables.
|
||||
"""
|
||||
base_aws_llm = BaseAWSLLM()
|
||||
|
||||
# Test case 1: boto3.Session() returns a configured region
|
||||
with patch("litellm.llms.bedrock.base_aws_llm.get_secret") as mock_get_secret:
|
||||
mock_get_secret.return_value = None # No region in env vars
|
||||
|
||||
with patch("boto3.Session") as mock_boto3_session:
|
||||
mock_session = MagicMock()
|
||||
mock_session.region_name = "us-east-1"
|
||||
mock_boto3_session.return_value = mock_session
|
||||
|
||||
optional_params = {}
|
||||
result = base_aws_llm._get_aws_region_name(optional_params)
|
||||
|
||||
assert result == "us-east-1"
|
||||
mock_boto3_session.assert_called_once()
|
||||
|
||||
# Test case 2: boto3.Session() returns None for region (should default to us-west-2)
|
||||
with patch("litellm.llms.bedrock.base_aws_llm.get_secret") as mock_get_secret:
|
||||
mock_get_secret.return_value = None # No region in env vars
|
||||
|
||||
with patch("boto3.Session") as mock_boto3_session:
|
||||
mock_session = MagicMock()
|
||||
mock_session.region_name = None
|
||||
mock_boto3_session.return_value = mock_session
|
||||
|
||||
optional_params = {}
|
||||
result = base_aws_llm._get_aws_region_name(optional_params)
|
||||
|
||||
assert result == "us-west-2"
|
||||
mock_boto3_session.assert_called_once()
|
||||
|
||||
# Test case 3: boto3 import/session creation raises exception (should default to us-west-2)
|
||||
with patch("litellm.llms.bedrock.base_aws_llm.get_secret") as mock_get_secret:
|
||||
mock_get_secret.return_value = None # No region in env vars
|
||||
|
||||
with patch("boto3.Session") as mock_boto3_session:
|
||||
mock_boto3_session.side_effect = Exception("boto3 not available")
|
||||
|
||||
optional_params = {}
|
||||
result = base_aws_llm._get_aws_region_name(optional_params)
|
||||
|
||||
assert result == "us-west-2"
|
||||
mock_boto3_session.assert_called_once()
|
||||
|
||||
# Test case 4: aws_region_name is provided in optional_params (should not use boto3)
|
||||
with patch("boto3.Session") as mock_boto3_session:
|
||||
optional_params = {"aws_region_name": "eu-west-1"}
|
||||
result = base_aws_llm._get_aws_region_name(optional_params)
|
||||
|
||||
assert result == "eu-west-1"
|
||||
mock_boto3_session.assert_not_called()
|
||||
|
||||
# Test case 5: aws_region_name found in environment variables (should not use boto3)
|
||||
with patch("litellm.llms.bedrock.base_aws_llm.get_secret") as mock_get_secret:
|
||||
|
||||
def side_effect(key, default=None):
|
||||
if key == "AWS_REGION_NAME":
|
||||
return "ap-southeast-1"
|
||||
return default
|
||||
|
||||
mock_get_secret.side_effect = side_effect
|
||||
|
||||
with patch("boto3.Session") as mock_boto3_session:
|
||||
optional_params = {}
|
||||
result = base_aws_llm._get_aws_region_name(optional_params)
|
||||
|
||||
assert result == "ap-southeast-1"
|
||||
mock_boto3_session.assert_not_called()
|
||||
|
||||
|
||||
def test_sign_request_with_env_var_bearer_token():
|
||||
# Create instance of actual class
|
||||
llm = BaseAWSLLM()
|
||||
|
||||
# Test data
|
||||
service_name = "bedrock"
|
||||
headers = {"Custom-Header": "test"}
|
||||
optional_params = {}
|
||||
request_data = {"prompt": "test"}
|
||||
api_base = "https://api.example.com"
|
||||
|
||||
# Mock environment variable
|
||||
with patch.dict(os.environ, {"AWS_BEARER_TOKEN_BEDROCK": "test_token"}):
|
||||
# Execute
|
||||
result_headers, result_body = llm._sign_request(
|
||||
service_name=service_name,
|
||||
headers=headers,
|
||||
optional_params=optional_params,
|
||||
request_data=request_data,
|
||||
api_base=api_base,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result_headers["Authorization"] == "Bearer test_token"
|
||||
assert result_headers["Content-Type"] == "application/json"
|
||||
assert result_headers["Custom-Header"] == "test"
|
||||
assert result_body == json.dumps(request_data).encode()
|
||||
|
||||
|
||||
def test_sign_request_with_sigv4():
|
||||
llm = BaseAWSLLM()
|
||||
|
||||
# Mock AWS credentials and SigV4 auth
|
||||
mock_credentials = Credentials("test_key", "test_secret", "test_token")
|
||||
mock_sigv4 = MagicMock()
|
||||
mock_request = MagicMock()
|
||||
mock_request.headers = {
|
||||
"Authorization": "AWS4-HMAC-SHA256 Credential=test",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
mock_request.body = b'{"prompt": "test"}'
|
||||
|
||||
# Test data
|
||||
service_name = "bedrock"
|
||||
headers = {"Custom-Header": "test"}
|
||||
optional_params = {
|
||||
"aws_access_key_id": "test_key",
|
||||
"aws_secret_access_key": "test_secret",
|
||||
"aws_region_name": "us-west-2",
|
||||
}
|
||||
request_data = {"prompt": "test"}
|
||||
api_base = "https://api.example.com"
|
||||
|
||||
# Mock the necessary components
|
||||
with patch("botocore.auth.SigV4Auth", return_value=mock_sigv4), patch(
|
||||
"botocore.awsrequest.AWSRequest", return_value=mock_request
|
||||
), patch.object(
|
||||
llm, "get_credentials", return_value=mock_credentials
|
||||
), patch.object(
|
||||
llm, "_get_aws_region_name", return_value="us-west-2"
|
||||
):
|
||||
result_headers, result_body = llm._sign_request(
|
||||
service_name=service_name,
|
||||
headers=headers,
|
||||
optional_params=optional_params,
|
||||
request_data=request_data,
|
||||
api_base=api_base,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert "Authorization" in result_headers
|
||||
assert result_headers["Authorization"] != "Bearer test_token"
|
||||
assert result_headers["Content-Type"] == "application/json"
|
||||
assert result_body == mock_request.body
|
||||
|
||||
|
||||
def test_sign_request_with_api_key_bearer_token():
|
||||
"""
|
||||
Test that _sign_request uses the api_key parameter as a bearer token when provided
|
||||
"""
|
||||
llm = BaseAWSLLM()
|
||||
|
||||
# Test data
|
||||
service_name = "bedrock"
|
||||
headers = {"Custom-Header": "test"}
|
||||
optional_params = {}
|
||||
request_data = {"prompt": "test"}
|
||||
api_base = "https://api.example.com"
|
||||
api_key = "test_api_key"
|
||||
|
||||
# Execute with api_key parameter
|
||||
result_headers, result_body = llm._sign_request(
|
||||
service_name=service_name,
|
||||
headers=headers,
|
||||
optional_params=optional_params,
|
||||
request_data=request_data,
|
||||
api_base=api_base,
|
||||
api_key=api_key,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result_headers["Authorization"] == f"Bearer {api_key}"
|
||||
assert result_headers["Content-Type"] == "application/json"
|
||||
assert result_headers["Custom-Header"] == "test"
|
||||
assert result_body == json.dumps(request_data).encode()
|
||||
|
||||
|
||||
def test_get_request_headers_with_env_var_bearer_token():
|
||||
# Setup
|
||||
llm = BaseAWSLLM()
|
||||
credentials = Credentials("test_key", "test_secret", "test_token")
|
||||
headers = {"Content-Type": "application/json"}
|
||||
headers_dict = headers.copy()
|
||||
|
||||
# Create mock request
|
||||
mock_prepared_request = MagicMock(spec=AWSPreparedRequest)
|
||||
mock_request = MagicMock(spec=AWSRequest)
|
||||
mock_request.headers = headers_dict
|
||||
mock_request.prepare.return_value = mock_prepared_request
|
||||
|
||||
def mock_aws_request_init(method, url, data, headers):
|
||||
mock_request.headers.update(headers)
|
||||
return mock_request
|
||||
|
||||
# Test with bearer token
|
||||
with patch.dict(os.environ, {"AWS_BEARER_TOKEN_BEDROCK": "test_token"}), patch(
|
||||
"botocore.awsrequest.AWSRequest", side_effect=mock_aws_request_init
|
||||
):
|
||||
result = llm.get_request_headers(
|
||||
credentials=credentials,
|
||||
aws_region_name="us-west-2",
|
||||
extra_headers=None,
|
||||
endpoint_url="https://api.example.com",
|
||||
data='{"prompt": "test"}',
|
||||
headers=headers_dict,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert mock_request.headers["Authorization"] == "Bearer test_token"
|
||||
assert result == mock_prepared_request
|
||||
|
||||
|
||||
def test_get_request_headers_with_sigv4():
|
||||
# Setup
|
||||
llm = BaseAWSLLM()
|
||||
credentials = Credentials("test_key", "test_secret", "test_token")
|
||||
headers = {"Content-Type": "application/json"}
|
||||
|
||||
# Create mock request and SigV4 instance
|
||||
mock_request = MagicMock(spec=AWSRequest)
|
||||
mock_request.headers = headers.copy()
|
||||
mock_request.prepare.return_value = MagicMock(spec=AWSPreparedRequest)
|
||||
|
||||
mock_sigv4 = MagicMock()
|
||||
|
||||
# Test without bearer token (should use SigV4)
|
||||
with patch.dict(os.environ, {}, clear=True), patch(
|
||||
"botocore.auth.SigV4Auth", return_value=mock_sigv4
|
||||
) as mock_sigv4_class, patch(
|
||||
"botocore.awsrequest.AWSRequest", return_value=mock_request
|
||||
):
|
||||
result = llm.get_request_headers(
|
||||
credentials=credentials,
|
||||
aws_region_name="us-west-2",
|
||||
extra_headers=None,
|
||||
endpoint_url="https://api.example.com",
|
||||
data='{"prompt": "test"}',
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
# Verify SigV4 authentication and result
|
||||
mock_sigv4_class.assert_called_once_with(credentials, "bedrock", "us-west-2")
|
||||
mock_sigv4.add_auth.assert_called_once_with(mock_request)
|
||||
assert result == mock_request.prepare.return_value
|
||||
|
||||
|
||||
def test_get_request_headers_with_api_key_bearer_token():
|
||||
"""
|
||||
Test that get_request_headers uses the api_key parameter as a bearer token when provided
|
||||
"""
|
||||
# Setup
|
||||
llm = BaseAWSLLM()
|
||||
credentials = Credentials("test_key", "test_secret", "test_token")
|
||||
headers = {"Content-Type": "application/json"}
|
||||
headers_dict = headers.copy()
|
||||
api_key = "test_api_key"
|
||||
|
||||
# Create mock request
|
||||
mock_prepared_request = MagicMock(spec=AWSPreparedRequest)
|
||||
mock_request = MagicMock(spec=AWSRequest)
|
||||
mock_request.headers = headers_dict
|
||||
mock_request.prepare.return_value = mock_prepared_request
|
||||
|
||||
def mock_aws_request_init(method, url, data, headers):
|
||||
mock_request.headers.update(headers)
|
||||
return mock_request
|
||||
|
||||
# Test with api_key parameter
|
||||
with patch.dict(os.environ, {}, clear=True), patch(
|
||||
"botocore.awsrequest.AWSRequest", side_effect=mock_aws_request_init
|
||||
):
|
||||
result = llm.get_request_headers(
|
||||
credentials=credentials,
|
||||
aws_region_name="us-west-2",
|
||||
extra_headers=None,
|
||||
endpoint_url="https://api.example.com",
|
||||
data='{"prompt": "test"}',
|
||||
headers=headers_dict,
|
||||
api_key=api_key,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert mock_request.headers["Authorization"] == f"Bearer {api_key}"
|
||||
assert result == mock_prepared_request
|
||||
|
||||
|
||||
def test_role_assumption_without_session_name():
|
||||
"""
|
||||
Test for issue 12583: Role assumption should work when only aws_role_name is provided
|
||||
without aws_session_name. The system should auto-generate a session name.
|
||||
"""
|
||||
base_aws_llm = BaseAWSLLM()
|
||||
|
||||
# Mock the boto3 STS client
|
||||
mock_sts_client = MagicMock()
|
||||
|
||||
# Mock the STS response with proper expiration handling
|
||||
mock_expiry = MagicMock()
|
||||
mock_expiry.tzinfo = timezone.utc
|
||||
current_time = datetime.now(timezone.utc)
|
||||
# Create a timedelta object that returns 3600 when total_seconds() is called
|
||||
time_diff = MagicMock()
|
||||
time_diff.total_seconds.return_value = 3600
|
||||
mock_expiry.__sub__ = MagicMock(return_value=time_diff)
|
||||
|
||||
mock_sts_response = {
|
||||
"Credentials": {
|
||||
"AccessKeyId": "assumed-access-key",
|
||||
"SecretAccessKey": "assumed-secret-key",
|
||||
"SessionToken": "assumed-session-token",
|
||||
"Expiration": mock_expiry,
|
||||
}
|
||||
}
|
||||
mock_sts_client.assume_role.return_value = mock_sts_response
|
||||
|
||||
# Test case 1: aws_role_name provided without aws_session_name
|
||||
with patch("boto3.client", return_value=mock_sts_client):
|
||||
credentials = base_aws_llm.get_credentials(
|
||||
aws_role_name="arn:aws:iam::2222222222222:role/LitellmEvalBedrockRole"
|
||||
)
|
||||
|
||||
# Verify assume_role was called
|
||||
mock_sts_client.assume_role.assert_called_once()
|
||||
|
||||
# Check the call arguments
|
||||
call_args = mock_sts_client.assume_role.call_args
|
||||
assert (
|
||||
call_args[1]["RoleArn"]
|
||||
== "arn:aws:iam::2222222222222:role/LitellmEvalBedrockRole"
|
||||
)
|
||||
# Session name should be auto-generated with format "litellm-session-{timestamp}"
|
||||
assert call_args[1]["RoleSessionName"].startswith("litellm-session-")
|
||||
|
||||
# Verify credentials are returned correctly
|
||||
assert isinstance(credentials, Credentials)
|
||||
assert credentials.access_key == "assumed-access-key"
|
||||
assert credentials.secret_key == "assumed-secret-key"
|
||||
assert credentials.token == "assumed-session-token"
|
||||
|
||||
# Test case 2: Both aws_role_name and aws_session_name provided (existing behavior)
|
||||
mock_sts_client.reset_mock()
|
||||
with patch("boto3.client", return_value=mock_sts_client):
|
||||
credentials = base_aws_llm.get_credentials(
|
||||
aws_role_name="arn:aws:iam::2222222222222:role/LitellmEvalBedrockRole",
|
||||
aws_session_name="my-custom-session",
|
||||
)
|
||||
|
||||
# Verify assume_role was called with custom session name
|
||||
mock_sts_client.assume_role.assert_called_once()
|
||||
call_args = mock_sts_client.assume_role.call_args
|
||||
assert call_args[1]["RoleSessionName"] == "my-custom-session"
|
||||
|
||||
# Test case 3: Verify caching works with auto-generated session names
|
||||
# Clear the cache first
|
||||
base_aws_llm.iam_cache = DualCache()
|
||||
|
||||
mock_sts_client.reset_mock()
|
||||
with patch("boto3.client", return_value=mock_sts_client):
|
||||
# First call
|
||||
credentials1 = base_aws_llm.get_credentials(
|
||||
aws_role_name="arn:aws:iam::2222222222222:role/LitellmEvalBedrockRole"
|
||||
)
|
||||
|
||||
# Second call with same role should use cache (not call assume_role again)
|
||||
credentials2 = base_aws_llm.get_credentials(
|
||||
aws_role_name="arn:aws:iam::2222222222222:role/LitellmEvalBedrockRole"
|
||||
)
|
||||
|
||||
# Should only be called once due to caching
|
||||
assert mock_sts_client.assume_role.call_count == 1
|
@@ -0,0 +1,21 @@
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../../../..")
|
||||
) # Adds the parent directory to the system path
|
||||
|
||||
|
||||
from litellm.llms.bedrock.common_utils import BedrockModelInfo
|
||||
|
||||
|
||||
def test_deepseek_cris():
|
||||
bedrock_model_info = BedrockModelInfo
|
||||
bedrock_route = bedrock_model_info.get_bedrock_route(
|
||||
model="bedrock/us.deepseek.r1-v1:0"
|
||||
)
|
||||
assert bedrock_route == "converse"
|
@@ -0,0 +1,27 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from litellm.llms.bedrock.vector_stores.transformation import BedrockVectorStoreConfig
|
||||
|
||||
|
||||
def test_transform_search_request():
|
||||
"""
|
||||
Test that BedrockVectorStoreConfig correctly transforms search vector store requests.
|
||||
|
||||
Verifies that the transformation creates the proper URL endpoint and request body
|
||||
with the expected retrievalQuery structure.
|
||||
"""
|
||||
config = BedrockVectorStoreConfig()
|
||||
mock_log = MagicMock()
|
||||
mock_log.model_call_details = {}
|
||||
|
||||
url, body = config.transform_search_vector_store_request(
|
||||
vector_store_id="kb123",
|
||||
query="hello",
|
||||
vector_store_search_optional_params={},
|
||||
api_base="https://bedrock-agent-runtime.us-west-2.amazonaws.com/knowledgebases",
|
||||
litellm_logging_obj=mock_log,
|
||||
litellm_params={},
|
||||
)
|
||||
|
||||
assert url.endswith("/kb123/retrieve")
|
||||
assert body["retrievalQuery"].get("text") == "hello"
|
Reference in New Issue
Block a user