# What this tests?
## This tests litellm's support for the OpenAI /images/generations endpoint
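## Typical local run (assuming pytest is installed and provider keys are set in .env):
##   pytest -s -v test_image_generation.py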
import asyncio
import json
import logging
import os
import sys
import tempfile
import traceback

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path

from dotenv import load_dotenv
from openai.types.image import Image

logging.basicConfig(level=logging.DEBUG)
load_dotenv()

import pytest

import litellm
from base_image_generation_test import BaseImageGenTest
from litellm._logging import verbose_logger
from litellm.caching import InMemoryCache

verbose_logger.setLevel(logging.DEBUG)

def get_vertex_ai_creds_json() -> dict:
    # Define the path to the vertex_key.json file
    print("loading vertex ai credentials")
    filepath = os.path.dirname(os.path.abspath(__file__))
    vertex_key_path = filepath + "/vertex_key.json"

    # Read the existing content of the file, or start from an empty dictionary
    try:
        with open(vertex_key_path, "r") as file:
            print("Read vertexai file path")
            content = file.read()
            if not content or not content.strip():
                # The file is empty - start from an empty dictionary
                service_account_key_data = {}
            else:
                # Load the existing JSON content
                file.seek(0)
                service_account_key_data = json.load(file)
    except FileNotFoundError:
        # The file doesn't exist - start from an empty dictionary
        service_account_key_data = {}

    # Overlay the private key fields from environment variables
    private_key_id = os.environ.get("VERTEX_AI_PRIVATE_KEY_ID", "")
    private_key = os.environ.get("VERTEX_AI_PRIVATE_KEY", "")
    private_key = private_key.replace("\\n", "\n")
    service_account_key_data["private_key_id"] = private_key_id
    service_account_key_data["private_key"] = private_key

    return service_account_key_data

def load_vertex_ai_credentials():
    # Reuse get_vertex_ai_creds_json() to build the service-account payload,
    # then materialize it as a temp file for the Google SDK to pick up
    service_account_key_data = get_vertex_ai_creds_json()

    # Write the updated content to a temporary file
    with tempfile.NamedTemporaryFile(mode="w+", delete=False) as temp_file:
        json.dump(service_account_key_data, temp_file, indent=2)

    # Export the temporary file as GOOGLE_APPLICATION_CREDENTIALS
    os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = os.path.abspath(temp_file.name)

class TestVertexImageGeneration(BaseImageGenTest):
    def get_base_image_generation_call_args(self) -> dict:
        # comment this out when running locally
        load_vertex_ai_credentials()
        litellm.in_memory_llm_clients_cache = InMemoryCache()
        return {
            "model": "vertex_ai/imagegeneration@006",
            "vertex_ai_project": "pathrise-convert-1606954137718",
            "vertex_ai_location": "us-central1",
            "n": 1,
        }


class TestBedrockSd3(BaseImageGenTest):
    def get_base_image_generation_call_args(self) -> dict:
        litellm.in_memory_llm_clients_cache = InMemoryCache()
        return {"model": "bedrock/stability.sd3-large-v1:0"}


class TestBedrockSd1(BaseImageGenTest):
    def get_base_image_generation_call_args(self) -> dict:
        litellm.in_memory_llm_clients_cache = InMemoryCache()
        return {"model": "bedrock/stability.sd3-large-v1:0"}


class TestBedrockNovaCanvasTextToImage(BaseImageGenTest):
    def get_base_image_generation_call_args(self) -> dict:
        litellm.in_memory_llm_clients_cache = InMemoryCache()
        return {
            "model": "bedrock/amazon.nova-canvas-v1:0",
            "n": 1,
            "size": "320x320",
            "imageGenerationConfig": {"cfgScale": 6.5, "seed": 12},
            "taskType": "TEXT_IMAGE",
            "aws_region_name": "us-east-1",
        }
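
# Note: taskType and imageGenerationConfig above (and colorGuidedGenerationParams
# below) are Bedrock Nova Canvas request fields rather than OpenAI-style params;
# litellm forwards such provider-specific keys through to the Bedrock request.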

class TestBedrockNovaCanvasColorGuidedGeneration(BaseImageGenTest):
    def get_base_image_generation_call_args(self) -> dict:
        litellm.in_memory_llm_clients_cache = InMemoryCache()
        return {
            "model": "bedrock/amazon.nova-canvas-v1:0",
            "n": 1,
            "size": "320x320",
            "imageGenerationConfig": {"cfgScale": 6.5, "seed": 12},
            "taskType": "COLOR_GUIDED_GENERATION",
            "colorGuidedGenerationParams": {"colors": ["#FFFFFF"]},
            "aws_region_name": "us-east-1",
        }


class TestOpenAIDalle3(BaseImageGenTest):
    def get_base_image_generation_call_args(self) -> dict:
        return {"model": "dall-e-3"}


class TestOpenAIGPTImage1(BaseImageGenTest):
    def get_base_image_generation_call_args(self) -> dict:
        return {"model": "gpt-image-1"}


class TestRecraftImageGeneration(BaseImageGenTest):
    def get_base_image_generation_call_args(self) -> dict:
        return {"model": "recraft/recraftv3"}


class TestGoogleImageGen(BaseImageGenTest):
    def get_base_image_generation_call_args(self) -> dict:
        return {"model": "gemini/imagen-4.0-generate-preview-06-06"}


class TestAzureOpenAIDalle3(BaseImageGenTest):
    def get_base_image_generation_call_args(self) -> dict:
        litellm.set_verbose = True
        return {
            "model": "azure/dall-e-3-test",
            "api_version": "2023-12-01-preview",
            "api_base": os.getenv("AZURE_SWEDEN_API_BASE"),
            "api_key": os.getenv("AZURE_SWEDEN_API_KEY"),
            "metadata": {
                "model_info": {
                    "base_model": "azure/dall-e-3",
                }
            },
        }


class TestAzureFoundryFlux(BaseImageGenTest):
    def get_base_image_generation_call_args(self) -> dict:
        litellm.set_verbose = True
        return {
            "model": "azure_ai/FLUX.1-Kontext-pro",
            "api_base": os.getenv("AZURE_FLUX_API_BASE"),
            "api_key": os.getenv("AZURE_GPT5_API_KEY"),
            "n": 1,
            "quality": "standard",
        }
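
# For reference, a rough sketch of the harness these classes rely on. This is an
# assumption inferred from how the subclasses use it, not the actual contents of
# base_image_generation_test.py:
#
#   class BaseImageGenTest(ABC):
#       @abstractmethod
#       def get_base_image_generation_call_args(self) -> dict: ...
#
#       @pytest.mark.asyncio
#       async def test_async_image_generation(self):
#           args = self.get_base_image_generation_call_args()
#           response = await litellm.aimage_generation(
#               prompt="A cute baby sea otter", **args
#           )
#           assert len(response.data) > 0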

@pytest.mark.flaky(retries=3, delay=1)
def test_image_generation_azure_dall_e_3():
    try:
        litellm.set_verbose = True
        response = litellm.image_generation(
            prompt="A cute baby sea otter",
            model="azure/dall-e-3-test",
            api_version="2023-12-01-preview",
            api_base=os.getenv("AZURE_SWEDEN_API_BASE"),
            api_key=os.getenv("AZURE_SWEDEN_API_KEY"),
            metadata={
                "model_info": {
                    "base_model": "azure/dall-e-3",
                }
            },
        )
        print(f"response: {response}")
        print("response", response._hidden_params)
        assert len(response.data) > 0
    except litellm.InternalServerError:
        pass
    except litellm.ContentPolicyViolationError:
        pass  # the provider randomly raises these errors - skip when they occur
    except litellm.RateLimitError:
        pass
    except Exception as e:
        # Tolerate known transient failures; anything else is a real test failure
        if "Your task failed as a result of our safety system." in str(e):
            pass
        elif "Connection error" in str(e):
            pass
        else:
            pytest.fail(f"An exception occurred - {str(e)}")
# asyncio.run(test_async_image_generation_openai())

@pytest.mark.skip(reason="model EOL")
@pytest.mark.asyncio
async def test_aimage_generation_bedrock_with_optional_params():
    try:
        litellm.in_memory_llm_clients_cache = InMemoryCache()
        response = await litellm.aimage_generation(
            prompt="A cute baby sea otter",
            model="bedrock/stability.stable-diffusion-xl-v1",
            size="256x256",
        )
        print(f"response: {response}")
    except litellm.RateLimitError:
        pass
    except litellm.ContentPolicyViolationError:
        pass  # the provider randomly raises these errors - skip when they occur
    except Exception as e:
        if "Your task failed as a result of our safety system." in str(e):
            pass
        else:
            pytest.fail(f"An exception occurred - {str(e)}")

@pytest.mark.asyncio
async def test_gpt_image_1_with_input_fidelity():
    """Test gpt-image-1 with the input_fidelity parameter (mocked)."""
    from unittest.mock import AsyncMock, patch

    # Mock OpenAI response payload
    mock_openai_response = {
        "created": 1703658209,
        "data": [{"url": "https://example.com/generated_image.png"}],
    }

    # Minimal stand-in for the OpenAI SDK response object
    class MockResponse:
        def model_dump(self):
            return mock_openai_response

    # Mock client exposing the images.generate coroutine
    mock_client = AsyncMock()
    mock_client.images.generate = AsyncMock(return_value=MockResponse())

    # Capture the actual arguments sent to the OpenAI client; note that
    # side_effect takes precedence over the return_value set above
    captured_args = None
    captured_kwargs = None

    async def capture_generate_call(*args, **kwargs):
        nonlocal captured_args, captured_kwargs
        captured_args = args
        captured_kwargs = kwargs
        return MockResponse()

    mock_client.images.generate.side_effect = capture_generate_call

    # Patch _get_openai_client so litellm talks to our mock instead of the API
    with patch.object(
        litellm.main.openai_chat_completions,
        "_get_openai_client",
        return_value=mock_client,
    ):
        response = await litellm.aimage_generation(
            prompt="A cute baby sea otter",
            model="gpt-image-1",
            input_fidelity="high",
            quality="medium",
            size="1024x1024",
        )

        # Validate the response
        assert response is not None
        assert response.created == 1703658209
        assert response.data is not None
        assert len(response.data) == 1
        assert response.data[0].url == "https://example.com/generated_image.png"

        # Validate that the OpenAI client was called with the correct parameters
        mock_client.images.generate.assert_called_once()
        assert captured_kwargs is not None
        assert captured_kwargs["model"] == "gpt-image-1"
        assert captured_kwargs["prompt"] == "A cute baby sea otter"
        assert captured_kwargs["input_fidelity"] == "high"
        assert captured_kwargs["quality"] == "medium"
        assert captured_kwargs["size"] == "1024x1024"