Added LiteLLM to the stack
This commit is contained in:
153
Development/litellm/tests/image_gen_tests/test_xinference.py
Normal file
153
Development/litellm/tests/image_gen_tests/test_xinference.py
Normal file
@@ -0,0 +1,153 @@
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
import pytest
|
||||
import json
|
||||
from unittest.mock import Mock, patch, AsyncMock
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system path
|
||||
|
||||
import litellm
|
||||
from litellm.types.utils import ImageObject
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_xinference_image_generation():
|
||||
"""Test basic xinference image generation with mocked OpenAI client."""
|
||||
|
||||
# Mock OpenAI response
|
||||
mock_openai_response = {
|
||||
"created": 1699623600,
|
||||
"data": [
|
||||
{
|
||||
"url": "https://example.com/image.png"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
# Create a proper mock response object
|
||||
class MockResponse:
|
||||
def model_dump(self):
|
||||
return mock_openai_response
|
||||
|
||||
# Create a mock client with the images.generate method
|
||||
mock_client = AsyncMock()
|
||||
mock_client.images.generate = AsyncMock(return_value=MockResponse())
|
||||
|
||||
# Capture the actual arguments sent to OpenAI client
|
||||
captured_args = None
|
||||
captured_kwargs = None
|
||||
|
||||
async def capture_generate_call(*args, **kwargs):
|
||||
nonlocal captured_args, captured_kwargs
|
||||
captured_args = args
|
||||
captured_kwargs = kwargs
|
||||
return MockResponse()
|
||||
|
||||
mock_client.images.generate.side_effect = capture_generate_call
|
||||
|
||||
# Mock the _get_openai_client method to return our mock client
|
||||
with patch.object(litellm.main.openai_chat_completions, '_get_openai_client', return_value=mock_client):
|
||||
response = await litellm.aimage_generation(
|
||||
model="xinference/stabilityai/stable-diffusion-3.5-large",
|
||||
prompt="A beautiful sunset over a calm ocean",
|
||||
api_base="http://mock.image.generation.api",
|
||||
)
|
||||
|
||||
# Print the captured arguments for debugging
|
||||
print("Arguments sent to openai_aclient.images.generate:")
|
||||
print("args:", json.dumps(captured_args, indent=4, default=str))
|
||||
print("kwargs:", json.dumps(captured_kwargs, indent=4, default=str))
|
||||
|
||||
# Validate the response
|
||||
assert response is not None
|
||||
assert response.created == 1699623600
|
||||
assert response.data is not None
|
||||
assert len(response.data) == 1
|
||||
assert response.data[0].url == "https://example.com/image.png"
|
||||
|
||||
# Validate that the OpenAI client was called with correct parameters
|
||||
mock_client.images.generate.assert_called_once()
|
||||
assert captured_kwargs is not None
|
||||
assert captured_kwargs["model"] == "stabilityai/stable-diffusion-3.5-large" # xinference/ prefix removed
|
||||
assert captured_kwargs["prompt"] == "A beautiful sunset over a calm ocean"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_xinference_image_generation_with_response_format():
|
||||
"""
|
||||
Test xinference image generation with additional parameters.
|
||||
Ensure all documented params are passed in.
|
||||
|
||||
https://inference.readthedocs.io/en/v1.1.1/reference/generated/xinference.client.handlers.ImageModelHandle.text_to_image.html#xinference.client.handlers.ImageModelHandle.text_to_image
|
||||
"""
|
||||
|
||||
# Mock OpenAI response
|
||||
mock_openai_response = {
|
||||
"created": 1699623600,
|
||||
"data": [
|
||||
{
|
||||
"b64_json": "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChAI9jU77yQAAAABJRU5ErkJggg=="
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
# Create a proper mock response object
|
||||
class MockResponse:
|
||||
def model_dump(self):
|
||||
return mock_openai_response
|
||||
|
||||
# Create a mock client with the images.generate method
|
||||
mock_client = AsyncMock()
|
||||
mock_client.images.generate = AsyncMock(return_value=MockResponse())
|
||||
|
||||
# Capture the actual arguments sent to OpenAI client
|
||||
captured_args = None
|
||||
captured_kwargs = None
|
||||
|
||||
async def capture_generate_call(*args, **kwargs):
|
||||
nonlocal captured_args, captured_kwargs
|
||||
captured_args = args
|
||||
captured_kwargs = kwargs
|
||||
return MockResponse()
|
||||
|
||||
mock_client.images.generate.side_effect = capture_generate_call
|
||||
|
||||
# Mock the _get_openai_client method to return our mock client
|
||||
with patch.object(litellm.main.openai_chat_completions, '_get_openai_client', return_value=mock_client):
|
||||
response = await litellm.aimage_generation(
|
||||
model="xinference/stabilityai/stable-diffusion-3.5-large",
|
||||
api_base="http://mock.image.generation.api",
|
||||
prompt="A beautiful sunset over a calm ocean",
|
||||
response_format="b64_json",
|
||||
n=1,
|
||||
size="1024x1024",
|
||||
)
|
||||
|
||||
# Print the captured arguments for debugging
|
||||
print("Arguments sent to openai_aclient.images.generate:")
|
||||
print("args:", json.dumps(captured_args, indent=4, default=str))
|
||||
print("kwargs:", json.dumps(captured_kwargs, indent=4, default=str))
|
||||
|
||||
# Validate the response
|
||||
assert response is not None
|
||||
assert response.created == 1699623600
|
||||
assert response.data is not None
|
||||
assert len(response.data) == 1
|
||||
assert response.data[0].b64_json is not None
|
||||
|
||||
# Validate that the OpenAI client was called with correct parameters
|
||||
mock_client.images.generate.assert_called_once()
|
||||
assert captured_kwargs is not None
|
||||
assert captured_kwargs["model"] == "stabilityai/stable-diffusion-3.5-large" # xinference/ prefix removed
|
||||
assert captured_kwargs["prompt"] == "A beautiful sunset over a calm ocean"
|
||||
assert captured_kwargs["response_format"] == "b64_json"
|
||||
assert captured_kwargs["n"] == 1
|
||||
assert captured_kwargs["size"] == "1024x1024"
|
||||
expected_args = ["model", "prompt", "response_format", "n", "size"]
|
||||
# only expected args should be present
|
||||
assert all(arg in captured_kwargs for arg in expected_args)
|
||||
|
Reference in New Issue
Block a user