Added LiteLLM to the stack
@@ -0,0 +1,167 @@
"""
AUDIT LOGGING

All /audit logging endpoints. Attempting to write these as CRUD endpoints.

GET - /audit/{id} - Get audit log by id
GET - /audit - Get all audit logs
"""

from typing import Any, Dict, Optional

#### AUDIT LOGGING ####
from fastapi import APIRouter, Depends, HTTPException, Query
from litellm_enterprise.types.proxy.audit_logging_endpoints import (
    AuditLogResponse,
    PaginatedAuditLogResponse,
)

from litellm.proxy._types import CommonProxyErrors, UserAPIKeyAuth
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth

router = APIRouter()


@router.get(
    "/audit",
    tags=["Audit Logging"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=PaginatedAuditLogResponse,
)
async def get_audit_logs(
    page: int = Query(1, ge=1),
    page_size: int = Query(10, ge=1, le=100),
    # Filter parameters
    changed_by: Optional[str] = Query(
        None, description="Filter by user or system that performed the action"
    ),
    changed_by_api_key: Optional[str] = Query(
        None, description="Filter by API key hash that performed the action"
    ),
    action: Optional[str] = Query(
        None, description="Filter by action type (create, update, delete)"
    ),
    table_name: Optional[str] = Query(
        None, description="Filter by table name that was modified"
    ),
    object_id: Optional[str] = Query(
        None, description="Filter by ID of the object that was modified"
    ),
    start_date: Optional[str] = Query(None, description="Filter logs after this date"),
    end_date: Optional[str] = Query(None, description="Filter logs before this date"),
    # Sorting parameters
    sort_by: Optional[str] = Query(
        None,
        description="Column to sort by (e.g. 'updated_at', 'action', 'table_name')",
    ),
    sort_order: str = Query("desc", description="Sort order ('asc' or 'desc')"),
):
    """
    Get all audit logs with filtering and pagination.

    Returns a paginated response of audit logs matching the specified filters.
    """
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(
            status_code=500,
            detail={"message": CommonProxyErrors.db_not_connected_error.value},
        )

    # Build filter conditions
    where_conditions: Dict[str, Any] = {}
    if changed_by:
        where_conditions["changed_by"] = changed_by
    if changed_by_api_key:
        where_conditions["changed_by_api_key"] = changed_by_api_key
    if action:
        where_conditions["action"] = action
    if table_name:
        where_conditions["table_name"] = table_name
    if object_id:
        where_conditions["object_id"] = object_id
    if start_date or end_date:
        date_filter = {}
        if start_date:
            date_filter["gte"] = start_date
        if end_date:
            date_filter["lte"] = end_date
        where_conditions["updated_at"] = date_filter

    # Build sort conditions
    order_by = {}
    if sort_by and isinstance(sort_by, str):
        order_by[sort_by] = sort_order
    elif sort_order and isinstance(sort_order, str):
        order_by["updated_at"] = sort_order  # Default sort by updated_at

    # Get paginated results
    audit_logs = await prisma_client.db.litellm_auditlog.find_many(
        where=where_conditions,
        order=order_by,
        skip=(page - 1) * page_size,
        take=page_size,
    )

    # Get total count for pagination
    total_count = await prisma_client.db.litellm_auditlog.count(where=where_conditions)
    total_pages = -(-total_count // page_size)  # Ceiling division

    # Return paginated response
    return PaginatedAuditLogResponse(
        audit_logs=[
            AuditLogResponse(**audit_log.model_dump()) for audit_log in audit_logs
        ]
        if audit_logs
        else [],
        total=total_count,
        page=page,
        page_size=page_size,
        total_pages=total_pages,
    )


@router.get(
    "/audit/{id}",
    tags=["Audit Logging"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=AuditLogResponse,
    responses={
        404: {"description": "Audit log not found"},
        500: {"description": "Database connection error"},
    },
)
async def get_audit_log_by_id(
    id: str, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth)
):
    """
    Get detailed information about a specific audit log entry by its ID.

    Args:
        id (str): The unique identifier of the audit log entry

    Returns:
        AuditLogResponse: Detailed information about the audit log entry

    Raises:
        HTTPException: If the audit log is not found or if there's a database connection error
    """
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(
            status_code=500,
            detail={"message": CommonProxyErrors.db_not_connected_error.value},
        )

    # Get the audit log by ID
    audit_log = await prisma_client.db.litellm_auditlog.find_unique(where={"id": id})

    if audit_log is None:
        raise HTTPException(
            status_code=404, detail={"message": f"Audit log with ID {id} not found"}
        )

    # Convert to response model
    return AuditLogResponse(**audit_log.model_dump())
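For reference, a minimal sketch of how a client could query these audit endpoints once the proxy is running. The base URL, admin key, and `table_name` value are placeholders, not values from this commit; the response fields follow the `PaginatedAuditLogResponse` constructed above.

```python
import httpx

PROXY_BASE_URL = "http://localhost:4000"          # assumed proxy address
HEADERS = {"Authorization": "Bearer sk-admin-key"}  # placeholder admin key

# Page through audit logs for one table, newest first.
resp = httpx.get(
    f"{PROXY_BASE_URL}/audit",
    headers=HEADERS,
    params={
        "table_name": "LiteLLM_VerificationToken",  # example filter value
        "sort_by": "updated_at",
        "sort_order": "desc",
        "page": 1,
        "page_size": 25,
    },
)
resp.raise_for_status()
for entry in resp.json()["audit_logs"]:
    # Fetch the full record for each entry via GET /audit/{id}.
    detail = httpx.get(f"{PROXY_BASE_URL}/audit/{entry['id']}", headers=HEADERS)
    print(detail.json())
```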
@@ -0,0 +1,10 @@
"""
Enterprise Authentication Module for LiteLLM Proxy

This module contains enterprise-specific authentication functionality,
including custom SSO handlers and advanced authentication features.
"""

from .custom_sso_handler import EnterpriseCustomSSOHandler

__all__ = ["EnterpriseCustomSSOHandler"]
@@ -0,0 +1,86 @@
"""
Enterprise Custom SSO Handler for LiteLLM Proxy

This module contains enterprise-specific custom SSO authentication functionality
that allows users to implement their own SSO handling logic by providing custom
handlers that process incoming request headers and return OpenID objects.

Use this when you have an OAuth proxy in front of LiteLLM (where the OAuth proxy
has already authenticated the user) and you need to extract user information from
custom headers or other request attributes.
"""

from typing import TYPE_CHECKING, Dict, Optional, Union, cast

from fastapi import Request
from fastapi.responses import RedirectResponse

if TYPE_CHECKING:
    from fastapi_sso.sso.base import OpenID
else:
    from typing import Any as OpenID

from litellm.proxy.management_endpoints.types import CustomOpenID


class EnterpriseCustomSSOHandler:
    """
    Enterprise Custom SSO Handler for LiteLLM Proxy

    This class provides methods for handling custom SSO authentication flows
    where users can implement their own authentication logic by processing
    request headers and returning user information in OpenID format.
    """

    @staticmethod
    async def handle_custom_ui_sso_sign_in(
        request: Request,
    ) -> RedirectResponse:
        """
        Allow a user to execute their custom code to parse incoming request headers and return an OpenID object.

        Use this when you have an OAuth proxy in front of LiteLLM (where the OAuth proxy has already authenticated the user).

        Args:
            request: The FastAPI request object containing headers and other request data

        Returns:
            RedirectResponse: Redirect response that sends the user to the LiteLLM UI with an authentication token

        Raises:
            ValueError: If custom_ui_sso_sign_in_handler is not configured

        Example:
            This method is typically called when a user has already been authenticated by an
            external OAuth proxy and the proxy has added custom headers containing user information.
            The custom handler extracts this information and converts it to an OpenID object.
        """
        from fastapi_sso.sso.base import OpenID

        from litellm.integrations.custom_sso_handler import CustomSSOLoginHandler
        from litellm.proxy.proxy_server import (
            CommonProxyErrors,
            premium_user,
            user_custom_ui_sso_sign_in_handler,
        )

        if premium_user is not True:
            raise ValueError(CommonProxyErrors.not_premium_user.value)

        if user_custom_ui_sso_sign_in_handler is None:
            raise ValueError(
                "custom_ui_sso_sign_in_handler is not configured. Please set it in general_settings."
            )

        custom_sso_login_handler = cast(
            CustomSSOLoginHandler, user_custom_ui_sso_sign_in_handler
        )
        openid_response: OpenID = await custom_sso_login_handler.handle_custom_ui_sso_sign_in(
            request=request,
        )

        # Import here to avoid circular imports
        from litellm.proxy.management_endpoints.ui_sso import SSOAuthenticationHandler

        return await SSOAuthenticationHandler.get_redirect_response_from_openid(
            result=openid_response,
            request=request,
            received_response=None,
            generic_client_id=None,
            ui_access_mode=None,
        )
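A rough sketch of the kind of handler this flow expects: a subclass of `CustomSSOLoginHandler` that reads identity headers set by an upstream OAuth proxy and returns an `OpenID` object. The header names are made up, and the `OpenID` fields are assumed to be the standard `fastapi_sso` ones.

```python
from fastapi import Request
from fastapi_sso.sso.base import OpenID

from litellm.integrations.custom_sso_handler import CustomSSOLoginHandler


class HeaderBasedSSOHandler(CustomSSOLoginHandler):
    async def handle_custom_ui_sso_sign_in(self, request: Request) -> OpenID:
        # The upstream OAuth proxy is assumed to have already authenticated
        # the user and attached identity headers to the request.
        email = request.headers.get("x-forwarded-email")
        user_id = request.headers.get("x-forwarded-user")
        return OpenID(
            id=user_id,
            email=email,
            display_name=email,
            first_name=None,
            last_name=None,
            picture=None,
            provider="custom-oauth-proxy",
        )
```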
@@ -0,0 +1,66 @@
import os

from fastapi import HTTPException, status


class EnterpriseRouteChecks:
    @staticmethod
    def is_llm_api_route_disabled() -> bool:
        """
        Check if LLM API routes are disabled
        """
        from litellm.proxy._types import CommonProxyErrors
        from litellm.proxy.proxy_server import premium_user
        from litellm.secret_managers.main import get_secret_bool

        ## Check if DISABLE_LLM_API_ENDPOINTS is set
        if "DISABLE_LLM_API_ENDPOINTS" in os.environ:
            if not premium_user:
                raise HTTPException(
                    status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                    detail=f"🚨🚨🚨 DISABLING LLM API ENDPOINTS is an Enterprise feature\n🚨 {CommonProxyErrors.not_premium_user.value}",
                )

        return get_secret_bool("DISABLE_LLM_API_ENDPOINTS") is True

    @staticmethod
    def is_management_routes_disabled() -> bool:
        """
        Check if management routes are disabled
        """
        from litellm.proxy._types import CommonProxyErrors
        from litellm.proxy.proxy_server import premium_user
        from litellm.secret_managers.main import get_secret_bool

        if "DISABLE_ADMIN_ENDPOINTS" in os.environ:
            if not premium_user:
                raise HTTPException(
                    status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                    detail=f"🚨🚨🚨 DISABLING ADMIN ENDPOINTS is an Enterprise feature\n🚨 {CommonProxyErrors.not_premium_user.value}",
                )

        return get_secret_bool("DISABLE_ADMIN_ENDPOINTS") is True

    @staticmethod
    def should_call_route(route: str):
        """
        Check if the management or LLM API route is disabled and raise an exception if so
        """
        from litellm.proxy.auth.route_checks import RouteChecks

        if (
            RouteChecks.is_management_route(route=route)
            and EnterpriseRouteChecks.is_management_routes_disabled()
        ):
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Management routes are disabled for this instance.",
            )
        elif (
            RouteChecks.is_llm_api_route(route=route)
            and EnterpriseRouteChecks.is_llm_api_route_disabled()
        ):
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="LLM API routes are disabled for this instance.",
            )
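These checks are driven purely by environment variables. A minimal sketch of the intended behavior on an enterprise (premium) instance; the module path is assumed from this commit's layout, and the check only takes effect once the proxy's auth path calls `should_call_route`.

```python
import os

os.environ["DISABLE_ADMIN_ENDPOINTS"] = "True"

# Assumed import path for the class added above.
from litellm_enterprise.proxy.auth.route_checks import EnterpriseRouteChecks

# Management routes now raise HTTPException(403); LLM API routes stay open
# unless DISABLE_LLM_API_ENDPOINTS is also set.
EnterpriseRouteChecks.should_call_route("/team/new")          # raises 403
EnterpriseRouteChecks.should_call_route("/chat/completions")  # allowed
```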
@@ -0,0 +1,35 @@
from typing import Any, Optional

from fastapi import Request

from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import ProxyException, UserAPIKeyAuth


async def enterprise_custom_auth(
    request: Request, api_key: str, user_custom_auth: Optional[Any]
) -> Optional[UserAPIKeyAuth]:
    from litellm_enterprise.proxy.proxy_server import custom_auth_settings

    if user_custom_auth is None:
        return None

    if custom_auth_settings is None:
        return await user_custom_auth(request, api_key)

    if custom_auth_settings["mode"] == "on":
        return await user_custom_auth(request, api_key)
    elif custom_auth_settings["mode"] == "off":
        return None
    elif custom_auth_settings["mode"] == "auto":
        try:
            return await user_custom_auth(request, api_key)
        except ProxyException as e:
            raise e
        except Exception as e:
            verbose_proxy_logger.debug(
                f"Error in custom auth, checking litellm auth: {e}"
            )
            return None
    else:
        raise ValueError(f"Invalid mode: {custom_auth_settings['mode']}")
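Here `user_custom_auth` is the user-supplied `custom_auth` callable configured on the proxy. A minimal sketch of what such a callable might look like; the key value and user id are placeholders. With `mode: "auto"`, any failure other than a `ProxyException` is logged and LiteLLM's normal key auth runs instead; `"on"` always uses the custom result, `"off"` skips it entirely.

```python
# custom_auth.py (referenced from the proxy config) - illustrative sketch
from fastapi import Request

from litellm.proxy._types import UserAPIKeyAuth


async def user_custom_auth(request: Request, api_key: str) -> UserAPIKeyAuth:
    # Accept a single hard-coded key; anything else falls through to
    # LiteLLM's own auth when mode == "auto", or is rejected when mode == "on".
    if api_key == "sk-internal-service":
        return UserAPIKeyAuth(api_key=api_key, user_id="internal-service")
    raise Exception("Unknown API key")
```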
@@ -0,0 +1,188 @@
"""
Polls LiteLLM_ManagedObjectTable to check if the batch job is complete, and if the cost has been tracked.
"""

import uuid
from datetime import datetime
from typing import TYPE_CHECKING, Optional, cast

from litellm._logging import verbose_proxy_logger

if TYPE_CHECKING:
    from litellm.proxy.utils import PrismaClient, ProxyLogging
    from litellm.router import Router


class CheckBatchCost:
    def __init__(
        self,
        proxy_logging_obj: "ProxyLogging",
        prisma_client: "PrismaClient",
        llm_router: "Router",
    ):
        from litellm.proxy.utils import PrismaClient, ProxyLogging
        from litellm.router import Router

        self.proxy_logging_obj: ProxyLogging = proxy_logging_obj
        self.prisma_client: PrismaClient = prisma_client
        self.llm_router: Router = llm_router

    async def check_batch_cost(self):
        """
        Check whether each pending batch job's cost has been tracked.
        - get all status="validating" and file_purpose="batch" jobs
        - check if the batch is now complete
        - if not, skip the job
        - if so, track its cost and mark the job as complete
        """
        from litellm_enterprise.proxy.hooks.managed_files import (
            _PROXY_LiteLLMManagedFiles,
        )

        from litellm.batches.batch_utils import (
            _get_file_content_as_dictionary,
            calculate_batch_cost_and_usage,
        )
        from litellm.litellm_core_utils.get_llm_provider_logic import get_llm_provider
        from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
        from litellm.proxy.openai_files_endpoints.common_utils import (
            _is_base64_encoded_unified_file_id,
            get_batch_id_from_unified_batch_id,
            get_model_id_from_unified_batch_id,
        )

        jobs = await self.prisma_client.db.litellm_managedobjecttable.find_many(
            where={
                "status": "validating",
                "file_purpose": "batch",
            }
        )

        completed_jobs = []

        for job in jobs:
            # get the model from the job
            unified_object_id = job.unified_object_id
            decoded_unified_object_id = _is_base64_encoded_unified_file_id(
                unified_object_id
            )
            if not decoded_unified_object_id:
                verbose_proxy_logger.info(
                    f"Skipping job {unified_object_id} because it is not a valid unified object id"
                )
                continue
            else:
                unified_object_id = decoded_unified_object_id

            model_id = get_model_id_from_unified_batch_id(unified_object_id)
            batch_id = get_batch_id_from_unified_batch_id(unified_object_id)

            if model_id is None:
                verbose_proxy_logger.info(
                    f"Skipping job {unified_object_id} because it does not contain a valid model id"
                )
                continue

            verbose_proxy_logger.info(
                f"Querying model ID: {model_id} for cost and usage of batch ID: {batch_id}"
            )

            try:
                response = await self.llm_router.aretrieve_batch(
                    model=model_id,
                    batch_id=batch_id,
                    litellm_metadata={
                        "user_api_key_user_id": job.created_by or "default-user-id",
                        "batch_ignore_default_logging": True,
                    },
                )
            except Exception as e:
                verbose_proxy_logger.info(
                    f"Skipping job {unified_object_id} because of error querying model ID: {model_id} for cost and usage of batch ID: {batch_id}: {e}"
                )
                continue

            ## RETRIEVE THE BATCH JOB OUTPUT FILE
            managed_files_obj = cast(
                Optional[_PROXY_LiteLLMManagedFiles],
                self.proxy_logging_obj.get_proxy_hook("managed_files"),
            )
            if (
                response.status == "completed"
                and response.output_file_id is not None
                and managed_files_obj is not None
            ):
                verbose_proxy_logger.info(
                    f"Batch ID: {batch_id} is complete, tracking cost and usage"
                )
                # track cost
                model_file_id_mapping = {
                    response.output_file_id: {model_id: response.output_file_id}
                }
                _file_content = await managed_files_obj.afile_content(
                    file_id=response.output_file_id,
                    litellm_parent_otel_span=None,
                    llm_router=self.llm_router,
                    model_file_id_mapping=model_file_id_mapping,
                )

                file_content_as_dict = _get_file_content_as_dictionary(
                    _file_content.content
                )

                deployment_info = self.llm_router.get_deployment(model_id=model_id)
                if deployment_info is None:
                    verbose_proxy_logger.info(
                        f"Skipping job {unified_object_id} because no valid deployment info was found for model ID: {model_id}"
                    )
                    continue
                custom_llm_provider = deployment_info.litellm_params.custom_llm_provider
                litellm_model_name = deployment_info.litellm_params.model

                _, llm_provider, _, _ = get_llm_provider(
                    model=litellm_model_name,
                    custom_llm_provider=custom_llm_provider,
                )

                batch_cost, batch_usage, batch_models = (
                    await calculate_batch_cost_and_usage(
                        file_content_dictionary=file_content_as_dict,
                        custom_llm_provider=llm_provider,  # type: ignore
                    )
                )

                logging_obj = LiteLLMLogging(
                    model=batch_models[0],
                    messages=[{"role": "user", "content": "<retrieve_batch>"}],
                    stream=False,
                    call_type="aretrieve_batch",
                    start_time=datetime.now(),
                    litellm_call_id=str(uuid.uuid4()),
                    function_id=str(uuid.uuid4()),
                )

                logging_obj.update_environment_variables(
                    litellm_params={
                        "metadata": {
                            "user_api_key_user_id": job.created_by or "default-user-id",
                        }
                    },
                    optional_params={},
                )

                await logging_obj.async_success_handler(
                    result=response,
                    batch_cost=batch_cost,
                    batch_usage=batch_usage,
                    batch_models=batch_models,
                )

                # mark the job as complete
                completed_jobs.append(job)

        if len(completed_jobs) > 0:
            # mark the jobs as complete
            await self.prisma_client.db.litellm_managedobjecttable.update_many(
                where={"id": {"in": [job.id for job in completed_jobs]}},
                data={"status": "complete"},
            )
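A sketch of how this checker could be wired up as a recurring background task on the proxy. The polling loop and interval are assumptions; the constructor arguments are the ones defined above.

```python
import asyncio


async def poll_batch_costs(proxy_logging_obj, prisma_client, llm_router):
    checker = CheckBatchCost(
        proxy_logging_obj=proxy_logging_obj,
        prisma_client=prisma_client,
        llm_router=llm_router,
    )
    while True:
        await checker.check_batch_cost()  # cost-tracks newly completed batches
        await asyncio.sleep(3600)         # assumed polling interval
```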
@@ -0,0 +1,30 @@
from fastapi import APIRouter
from fastapi.responses import Response
from litellm_enterprise.enterprise_callbacks.send_emails.endpoints import (
    router as email_events_router,
)

from .audit_logging_endpoints import router as audit_logging_router
from .guardrails.endpoints import router as guardrails_router
from .management_endpoints import management_endpoints_router
from .utils import _should_block_robots
from .vector_stores.endpoints import router as vector_stores_router

router = APIRouter()
router.include_router(vector_stores_router)
router.include_router(guardrails_router)
router.include_router(email_events_router)
router.include_router(audit_logging_router)
router.include_router(management_endpoints_router)


@router.get("/robots.txt")
async def get_robots():
    """
    Block all web crawlers from indexing the proxy server endpoints
    This is useful for ensuring that the API endpoints aren't indexed by search engines
    """
    if _should_block_robots():
        return Response(content="User-agent: *\nDisallow: /", media_type="text/plain")
    else:
        return Response(status_code=404)
@@ -0,0 +1,41 @@
"""
Enterprise Guardrail Routes on LiteLLM Proxy

To see all free guardrails see litellm/proxy/guardrails/*


Exposed Routes:
- /guardrails/apply_guardrail
"""
from typing import Optional

from fastapi import APIRouter, Depends

from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.guardrails.guardrail_endpoints import GUARDRAIL_REGISTRY
from litellm.types.guardrails import ApplyGuardrailRequest, ApplyGuardrailResponse

router = APIRouter(tags=["guardrails"], prefix="/guardrails")


@router.post("/apply_guardrail", response_model=ApplyGuardrailResponse)
async def apply_guardrail(
    request: ApplyGuardrailRequest,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Apply a guardrail (e.g. PII masking) to the given text. Requires the guardrail to be added to litellm.
    """
    active_guardrail: Optional[
        CustomGuardrail
    ] = GUARDRAIL_REGISTRY.get_initialized_guardrail_callback(
        guardrail_name=request.guardrail_name
    )
    if active_guardrail is None:
        raise Exception(f"Guardrail {request.guardrail_name} not found")

    return await active_guardrail.apply_guardrail(
        text=request.text, language=request.language, entities=request.entities
    )
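A minimal sketch of calling this route from a client. The proxy URL, key, and guardrail name are placeholders (the name must match a guardrail configured on the proxy); the request fields mirror `ApplyGuardrailRequest` as used above.

```python
import httpx

resp = httpx.post(
    "http://localhost:4000/guardrails/apply_guardrail",  # assumed proxy URL
    headers={"Authorization": "Bearer sk-your-key"},      # placeholder key
    json={
        "guardrail_name": "presidio-pii",                 # placeholder guardrail name
        "text": "My name is Jane Doe, call me at 555-0100",
        "language": "en",
        "entities": ["PERSON", "PHONE_NUMBER"],
    },
)
print(resp.json())
```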
@@ -0,0 +1,830 @@
# What is this?
## This hook is used to check for LiteLLM managed files in the request body, and replace them with model-specific file id

import asyncio
import base64
import json
import uuid
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union, cast

from fastapi import HTTPException

from litellm import Router, verbose_logger
from litellm.caching.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger
from litellm.litellm_core_utils.prompt_templates.common_utils import extract_file_data
from litellm.llms.base_llm.files.transformation import BaseFileEndpoints
from litellm.proxy._types import (
    CallTypes,
    LiteLLM_ManagedFileTable,
    LiteLLM_ManagedObjectTable,
    UserAPIKeyAuth,
)
from litellm.proxy.openai_files_endpoints.common_utils import (
    _is_base64_encoded_unified_file_id,
    convert_b64_uid_to_unified_uid,
    get_batch_id_from_unified_batch_id,
    get_model_id_from_unified_batch_id,
)
from litellm.types.llms.openai import (
    AllMessageValues,
    AsyncCursorPage,
    ChatCompletionFileObject,
    CreateFileRequest,
    FileObject,
    OpenAIFileObject,
    OpenAIFilesPurpose,
)
from litellm.types.utils import (
    LiteLLMBatch,
    LiteLLMFineTuningJob,
    LLMResponseTypes,
    SpecialEnums,
)

if TYPE_CHECKING:
    from litellm.types.llms.openai import HttpxBinaryResponseContent


if TYPE_CHECKING:
    from opentelemetry.trace import Span as _Span

    from litellm.proxy.utils import InternalUsageCache as _InternalUsageCache
    from litellm.proxy.utils import PrismaClient as _PrismaClient

    Span = Union[_Span, Any]
    InternalUsageCache = _InternalUsageCache
    PrismaClient = _PrismaClient
else:
    Span = Any
    InternalUsageCache = Any
    PrismaClient = Any


class _PROXY_LiteLLMManagedFiles(CustomLogger, BaseFileEndpoints):
|
||||
# Class variables or attributes
|
||||
def __init__(
|
||||
self, internal_usage_cache: InternalUsageCache, prisma_client: PrismaClient
|
||||
):
|
||||
self.internal_usage_cache = internal_usage_cache
|
||||
self.prisma_client = prisma_client
|
||||
|
||||
async def store_unified_file_id(
|
||||
self,
|
||||
file_id: str,
|
||||
file_object: Optional[OpenAIFileObject],
|
||||
litellm_parent_otel_span: Optional[Span],
|
||||
model_mappings: Dict[str, str],
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
) -> None:
|
||||
verbose_logger.info(
|
||||
f"Storing LiteLLM Managed File object with id={file_id} in cache"
|
||||
)
|
||||
if file_object is not None:
|
||||
litellm_managed_file_object = LiteLLM_ManagedFileTable(
|
||||
unified_file_id=file_id,
|
||||
file_object=file_object,
|
||||
model_mappings=model_mappings,
|
||||
flat_model_file_ids=list(model_mappings.values()),
|
||||
created_by=user_api_key_dict.user_id,
|
||||
updated_by=user_api_key_dict.user_id,
|
||||
)
|
||||
await self.internal_usage_cache.async_set_cache(
|
||||
key=file_id,
|
||||
value=litellm_managed_file_object.model_dump(),
|
||||
litellm_parent_otel_span=litellm_parent_otel_span,
|
||||
)
|
||||
|
||||
## STORE MODEL MAPPINGS IN DB
|
||||
|
||||
db_data = {
|
||||
"unified_file_id": file_id,
|
||||
"model_mappings": json.dumps(model_mappings),
|
||||
"flat_model_file_ids": list(model_mappings.values()),
|
||||
"created_by": user_api_key_dict.user_id,
|
||||
"updated_by": user_api_key_dict.user_id,
|
||||
}
|
||||
|
||||
if file_object is not None:
|
||||
db_data["file_object"] = file_object.model_dump_json()
|
||||
|
||||
result = await self.prisma_client.db.litellm_managedfiletable.create(
|
||||
data=db_data
|
||||
)
|
||||
verbose_logger.debug(
|
||||
f"LiteLLM Managed File object with id={file_id} stored in db: {result}"
|
||||
)
|
||||
|
||||
async def store_unified_object_id(
|
||||
self,
|
||||
unified_object_id: str,
|
||||
file_object: Union[LiteLLMBatch, LiteLLMFineTuningJob],
|
||||
litellm_parent_otel_span: Optional[Span],
|
||||
model_object_id: str,
|
||||
file_purpose: Literal["batch", "fine-tune"],
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
) -> None:
|
||||
verbose_logger.info(
|
||||
f"Storing LiteLLM Managed {file_purpose} object with id={unified_object_id} in cache"
|
||||
)
|
||||
litellm_managed_object = LiteLLM_ManagedObjectTable(
|
||||
unified_object_id=unified_object_id,
|
||||
model_object_id=model_object_id,
|
||||
file_purpose=file_purpose,
|
||||
file_object=file_object,
|
||||
)
|
||||
await self.internal_usage_cache.async_set_cache(
|
||||
key=unified_object_id,
|
||||
value=litellm_managed_object.model_dump(),
|
||||
litellm_parent_otel_span=litellm_parent_otel_span,
|
||||
)
|
||||
|
||||
await self.prisma_client.db.litellm_managedobjecttable.upsert(
|
||||
where={"unified_object_id": unified_object_id},
|
||||
data={
|
||||
"create": {
|
||||
"unified_object_id": unified_object_id,
|
||||
"file_object": file_object.model_dump_json(),
|
||||
"model_object_id": model_object_id,
|
||||
"file_purpose": file_purpose,
|
||||
"created_by": user_api_key_dict.user_id,
|
||||
"updated_by": user_api_key_dict.user_id,
|
||||
"status": file_object.status,
|
||||
},
|
||||
"update": {}, # don't do anything if it already exists
|
||||
}
|
||||
)
|
||||
|
||||
async def get_unified_file_id(
|
||||
self, file_id: str, litellm_parent_otel_span: Optional[Span] = None
|
||||
) -> Optional[LiteLLM_ManagedFileTable]:
|
||||
## CHECK CACHE
|
||||
result = cast(
|
||||
Optional[dict],
|
||||
await self.internal_usage_cache.async_get_cache(
|
||||
key=file_id,
|
||||
litellm_parent_otel_span=litellm_parent_otel_span,
|
||||
),
|
||||
)
|
||||
|
||||
if result:
|
||||
return LiteLLM_ManagedFileTable(**result)
|
||||
|
||||
## CHECK DB
|
||||
db_object = await self.prisma_client.db.litellm_managedfiletable.find_first(
|
||||
where={"unified_file_id": file_id}
|
||||
)
|
||||
|
||||
if db_object:
|
||||
return LiteLLM_ManagedFileTable(**db_object.model_dump())
|
||||
return None
|
||||
|
||||
async def delete_unified_file_id(
|
||||
self, file_id: str, litellm_parent_otel_span: Optional[Span] = None
|
||||
) -> OpenAIFileObject:
|
||||
## get old value
|
||||
initial_value = await self.prisma_client.db.litellm_managedfiletable.find_first(
|
||||
where={"unified_file_id": file_id}
|
||||
)
|
||||
if initial_value is None:
|
||||
raise Exception(f"LiteLLM Managed File object with id={file_id} not found")
|
||||
## delete old value
|
||||
await self.internal_usage_cache.async_set_cache(
|
||||
key=file_id,
|
||||
value=None,
|
||||
litellm_parent_otel_span=litellm_parent_otel_span,
|
||||
)
|
||||
await self.prisma_client.db.litellm_managedfiletable.delete(
|
||||
where={"unified_file_id": file_id}
|
||||
)
|
||||
return initial_value.file_object
|
||||
|
||||
async def can_user_call_unified_file_id(
|
||||
self, unified_file_id: str, user_api_key_dict: UserAPIKeyAuth
|
||||
) -> bool:
|
||||
## check if the user has access to the unified file id
|
||||
|
||||
user_id = user_api_key_dict.user_id
|
||||
managed_file = await self.prisma_client.db.litellm_managedfiletable.find_first(
|
||||
where={"unified_file_id": unified_file_id}
|
||||
)
|
||||
|
||||
if managed_file:
|
||||
return managed_file.created_by == user_id
|
||||
return False
|
||||
|
||||
async def can_user_call_unified_object_id(
|
||||
self, unified_object_id: str, user_api_key_dict: UserAPIKeyAuth
|
||||
) -> bool:
|
||||
        ## check if the user has access to the unified object id
|
||||
user_id = user_api_key_dict.user_id
|
||||
managed_object = (
|
||||
await self.prisma_client.db.litellm_managedobjecttable.find_first(
|
||||
where={"unified_object_id": unified_object_id}
|
||||
)
|
||||
)
|
||||
if managed_object:
|
||||
return managed_object.created_by == user_id
|
||||
return False
|
||||
|
||||
async def get_user_created_file_ids(
|
||||
self, user_api_key_dict: UserAPIKeyAuth, model_object_ids: List[str]
|
||||
) -> List[OpenAIFileObject]:
|
||||
"""
|
||||
Get all file ids created by the user for a list of model object ids
|
||||
|
||||
Returns:
|
||||
- List of OpenAIFileObject's
|
||||
"""
|
||||
file_ids = await self.prisma_client.db.litellm_managedfiletable.find_many(
|
||||
where={
|
||||
"created_by": user_api_key_dict.user_id,
|
||||
"flat_model_file_ids": {"hasSome": model_object_ids},
|
||||
}
|
||||
)
|
||||
return [OpenAIFileObject(**file_object.file_object) for file_object in file_ids]
|
||||
|
||||
async def check_managed_file_id_access(
|
||||
self, data: Dict, user_api_key_dict: UserAPIKeyAuth
|
||||
) -> bool:
|
||||
retrieve_file_id = cast(Optional[str], data.get("file_id"))
|
||||
potential_file_id = (
|
||||
_is_base64_encoded_unified_file_id(retrieve_file_id)
|
||||
if retrieve_file_id
|
||||
else False
|
||||
)
|
||||
if potential_file_id and retrieve_file_id:
|
||||
if await self.can_user_call_unified_file_id(
|
||||
retrieve_file_id, user_api_key_dict
|
||||
):
|
||||
return True
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail=f"User {user_api_key_dict.user_id} does not have access to the file {retrieve_file_id}",
|
||||
)
|
||||
return False
|
||||
|
||||
async def async_pre_call_hook(
|
||||
self,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
cache: DualCache,
|
||||
data: Dict,
|
||||
call_type: Literal[
|
||||
"completion",
|
||||
"text_completion",
|
||||
"embeddings",
|
||||
"image_generation",
|
||||
"moderation",
|
||||
"audio_transcription",
|
||||
"pass_through_endpoint",
|
||||
"rerank",
|
||||
"acreate_batch",
|
||||
"aretrieve_batch",
|
||||
"acreate_file",
|
||||
"afile_list",
|
||||
"afile_delete",
|
||||
"afile_content",
|
||||
"acreate_fine_tuning_job",
|
||||
"aretrieve_fine_tuning_job",
|
||||
"alist_fine_tuning_jobs",
|
||||
"acancel_fine_tuning_job",
|
||||
"mcp_call",
|
||||
],
|
||||
) -> Union[Exception, str, Dict, None]:
|
||||
"""
|
||||
- Detect litellm_proxy/ file_id
|
||||
- add dictionary of mappings of litellm_proxy/ file_id -> provider_file_id => {litellm_proxy/file_id: {"model_id": id, "file_id": provider_file_id}}
|
||||
"""
|
||||
### HANDLE FILE ACCESS ### - ensure user has access to the file
|
||||
if (
|
||||
call_type == CallTypes.afile_content.value
|
||||
or call_type == CallTypes.afile_delete.value
|
||||
):
|
||||
await self.check_managed_file_id_access(data, user_api_key_dict)
|
||||
|
||||
### HANDLE TRANSFORMATIONS ###
|
||||
if call_type == CallTypes.completion.value:
|
||||
messages = data.get("messages")
|
||||
if messages:
|
||||
file_ids = self.get_file_ids_from_messages(messages)
|
||||
if file_ids:
|
||||
model_file_id_mapping = await self.get_model_file_id_mapping(
|
||||
file_ids, user_api_key_dict.parent_otel_span
|
||||
)
|
||||
|
||||
data["model_file_id_mapping"] = model_file_id_mapping
|
||||
elif call_type == CallTypes.afile_content.value:
|
||||
retrieve_file_id = cast(Optional[str], data.get("file_id"))
|
||||
potential_file_id = (
|
||||
_is_base64_encoded_unified_file_id(retrieve_file_id)
|
||||
if retrieve_file_id
|
||||
else False
|
||||
)
|
||||
if potential_file_id:
|
||||
model_id = self.get_model_id_from_unified_file_id(potential_file_id)
|
||||
if model_id:
|
||||
data["model"] = model_id
|
||||
data["file_id"] = self.get_output_file_id_from_unified_file_id(
|
||||
potential_file_id
|
||||
)
|
||||
elif call_type == CallTypes.acreate_batch.value:
|
||||
input_file_id = cast(Optional[str], data.get("input_file_id"))
|
||||
if input_file_id:
|
||||
model_file_id_mapping = await self.get_model_file_id_mapping(
|
||||
[input_file_id], user_api_key_dict.parent_otel_span
|
||||
)
|
||||
|
||||
data["model_file_id_mapping"] = model_file_id_mapping
|
||||
elif (
|
||||
call_type == CallTypes.aretrieve_batch.value
|
||||
or call_type == CallTypes.acancel_fine_tuning_job.value
|
||||
or call_type == CallTypes.aretrieve_fine_tuning_job.value
|
||||
):
|
||||
accessor_key: Optional[str] = None
|
||||
retrieve_object_id: Optional[str] = None
|
||||
if call_type == CallTypes.aretrieve_batch.value:
|
||||
accessor_key = "batch_id"
|
||||
elif (
|
||||
call_type == CallTypes.acancel_fine_tuning_job.value
|
||||
or call_type == CallTypes.aretrieve_fine_tuning_job.value
|
||||
):
|
||||
accessor_key = "fine_tuning_job_id"
|
||||
|
||||
if accessor_key:
|
||||
retrieve_object_id = cast(Optional[str], data.get(accessor_key))
|
||||
|
||||
potential_llm_object_id = (
|
||||
_is_base64_encoded_unified_file_id(retrieve_object_id)
|
||||
if retrieve_object_id
|
||||
else False
|
||||
)
|
||||
if potential_llm_object_id and retrieve_object_id:
|
||||
## VALIDATE USER HAS ACCESS TO THE OBJECT ##
|
||||
if not await self.can_user_call_unified_object_id(
|
||||
retrieve_object_id, user_api_key_dict
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail=f"User {user_api_key_dict.user_id} does not have access to the object {retrieve_object_id}",
|
||||
)
|
||||
|
||||
## for managed batch id - get the model id
|
||||
potential_model_id = get_model_id_from_unified_batch_id(
|
||||
potential_llm_object_id
|
||||
)
|
||||
if potential_model_id is None:
|
||||
raise Exception(
|
||||
f"LiteLLM Managed {accessor_key} with id={retrieve_object_id} is invalid - does not contain encoded model_id."
|
||||
)
|
||||
data["model"] = potential_model_id
|
||||
data[accessor_key] = get_batch_id_from_unified_batch_id(
|
||||
potential_llm_object_id
|
||||
)
|
||||
elif call_type == CallTypes.acreate_fine_tuning_job.value:
|
||||
input_file_id = cast(Optional[str], data.get("training_file"))
|
||||
if input_file_id:
|
||||
model_file_id_mapping = await self.get_model_file_id_mapping(
|
||||
[input_file_id], user_api_key_dict.parent_otel_span
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
async def async_filter_deployments(
|
||||
self,
|
||||
model: str,
|
||||
healthy_deployments: List,
|
||||
messages: Optional[List[AllMessageValues]],
|
||||
request_kwargs: Optional[Dict] = None,
|
||||
parent_otel_span: Optional[Span] = None,
|
||||
) -> List[Dict]:
|
||||
if request_kwargs is None:
|
||||
return healthy_deployments
|
||||
|
||||
input_file_id = cast(Optional[str], request_kwargs.get("input_file_id"))
|
||||
model_file_id_mapping = cast(
|
||||
Optional[Dict[str, Dict[str, str]]],
|
||||
request_kwargs.get("model_file_id_mapping"),
|
||||
)
|
||||
allowed_model_ids = []
|
||||
if input_file_id and model_file_id_mapping:
|
||||
model_id_dict = model_file_id_mapping.get(input_file_id, {})
|
||||
allowed_model_ids = list(model_id_dict.keys())
|
||||
|
||||
if len(allowed_model_ids) == 0:
|
||||
return healthy_deployments
|
||||
|
||||
return [
|
||||
deployment
|
||||
for deployment in healthy_deployments
|
||||
if deployment.get("model_info", {}).get("id") in allowed_model_ids
|
||||
]
|
||||
|
||||
async def async_pre_call_deployment_hook(
|
||||
self, kwargs: Dict[str, Any], call_type: Optional[CallTypes]
|
||||
) -> Optional[dict]:
|
||||
"""
|
||||
Allow modifying the request just before it's sent to the deployment.
|
||||
"""
|
||||
accessor_key: Optional[str] = None
|
||||
if call_type and call_type == CallTypes.acreate_batch:
|
||||
accessor_key = "input_file_id"
|
||||
elif call_type and call_type == CallTypes.acreate_fine_tuning_job:
|
||||
accessor_key = "training_file"
|
||||
else:
|
||||
return kwargs
|
||||
|
||||
if accessor_key:
|
||||
input_file_id = cast(Optional[str], kwargs.get(accessor_key))
|
||||
model_file_id_mapping = cast(
|
||||
Optional[Dict[str, Dict[str, str]]], kwargs.get("model_file_id_mapping")
|
||||
)
|
||||
model_id = cast(Optional[str], kwargs.get("model_info", {}).get("id", None))
|
||||
mapped_file_id: Optional[str] = None
|
||||
if input_file_id and model_file_id_mapping and model_id:
|
||||
mapped_file_id = model_file_id_mapping.get(input_file_id, {}).get(
|
||||
model_id, None
|
||||
)
|
||||
if mapped_file_id:
|
||||
kwargs[accessor_key] = mapped_file_id
|
||||
|
||||
return kwargs
|
||||
|
||||
def get_file_ids_from_messages(self, messages: List[AllMessageValues]) -> List[str]:
|
||||
"""
|
||||
Gets file ids from messages
|
||||
"""
|
||||
file_ids = []
|
||||
for message in messages:
|
||||
if message.get("role") == "user":
|
||||
content = message.get("content")
|
||||
if content:
|
||||
if isinstance(content, str):
|
||||
continue
|
||||
for c in content:
|
||||
if c["type"] == "file":
|
||||
file_object = cast(ChatCompletionFileObject, c)
|
||||
file_object_file_field = file_object["file"]
|
||||
file_id = file_object_file_field.get("file_id")
|
||||
if file_id:
|
||||
file_ids.append(file_id)
|
||||
return file_ids
|
||||
|
||||
async def get_model_file_id_mapping(
|
||||
self, file_ids: List[str], litellm_parent_otel_span: Span
|
||||
) -> dict:
|
||||
"""
|
||||
Get model-specific file IDs for a list of proxy file IDs.
|
||||
Returns a dictionary mapping litellm_proxy/ file_id -> model_id -> model_file_id
|
||||
|
||||
1. Get all the litellm_proxy/ file_ids from the messages
|
||||
2. For each file_id, search for cache keys matching the pattern file_id:*
|
||||
3. Return a dictionary of mappings of litellm_proxy/ file_id -> model_id -> model_file_id
|
||||
|
||||
Example:
|
||||
{
|
||||
"litellm_proxy/file_id": {
|
||||
"model_id": "model_file_id"
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
file_id_mapping: Dict[str, Dict[str, str]] = {}
|
||||
litellm_managed_file_ids = []
|
||||
|
||||
for file_id in file_ids:
|
||||
            ## CHECK IF FILE ID IS MANAGED BY LITELLM
|
||||
is_base64_unified_file_id = _is_base64_encoded_unified_file_id(file_id)
|
||||
|
||||
if is_base64_unified_file_id:
|
||||
litellm_managed_file_ids.append(file_id)
|
||||
|
||||
if litellm_managed_file_ids:
|
||||
# Get all cache keys matching the pattern file_id:*
|
||||
for file_id in litellm_managed_file_ids:
|
||||
# Search for any cache key starting with this file_id
|
||||
unified_file_object = await self.get_unified_file_id(
|
||||
file_id, litellm_parent_otel_span
|
||||
)
|
||||
if unified_file_object:
|
||||
file_id_mapping[file_id] = unified_file_object.model_mappings
|
||||
|
||||
return file_id_mapping
|
||||
|
||||
async def create_file_for_each_model(
|
||||
self,
|
||||
llm_router: Optional[Router],
|
||||
_create_file_request: CreateFileRequest,
|
||||
target_model_names_list: List[str],
|
||||
litellm_parent_otel_span: Span,
|
||||
) -> List[OpenAIFileObject]:
|
||||
if llm_router is None:
|
||||
raise Exception("LLM Router not initialized. Ensure models added to proxy.")
|
||||
responses = []
|
||||
for model in target_model_names_list:
|
||||
individual_response = await llm_router.acreate_file(
|
||||
model=model, **_create_file_request
|
||||
)
|
||||
responses.append(individual_response)
|
||||
|
||||
return responses
|
||||
|
||||
async def acreate_file(
|
||||
self,
|
||||
create_file_request: CreateFileRequest,
|
||||
llm_router: Router,
|
||||
target_model_names_list: List[str],
|
||||
litellm_parent_otel_span: Span,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
) -> OpenAIFileObject:
|
||||
responses = await self.create_file_for_each_model(
|
||||
llm_router=llm_router,
|
||||
_create_file_request=create_file_request,
|
||||
target_model_names_list=target_model_names_list,
|
||||
litellm_parent_otel_span=litellm_parent_otel_span,
|
||||
)
|
||||
response = await _PROXY_LiteLLMManagedFiles.return_unified_file_id(
|
||||
file_objects=responses,
|
||||
create_file_request=create_file_request,
|
||||
internal_usage_cache=self.internal_usage_cache,
|
||||
litellm_parent_otel_span=litellm_parent_otel_span,
|
||||
target_model_names_list=target_model_names_list,
|
||||
)
|
||||
|
||||
## STORE MODEL MAPPINGS IN DB
|
||||
model_mappings: Dict[str, str] = {}
|
||||
|
||||
for file_object in responses:
|
||||
model_file_id_mapping = file_object._hidden_params.get(
|
||||
"model_file_id_mapping"
|
||||
)
|
||||
if model_file_id_mapping and isinstance(model_file_id_mapping, dict):
|
||||
model_mappings.update(model_file_id_mapping)
|
||||
|
||||
await self.store_unified_file_id(
|
||||
file_id=response.id,
|
||||
file_object=response,
|
||||
litellm_parent_otel_span=litellm_parent_otel_span,
|
||||
model_mappings=model_mappings,
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
)
|
||||
return response
|
||||
|
||||
@staticmethod
|
||||
async def return_unified_file_id(
|
||||
file_objects: List[OpenAIFileObject],
|
||||
create_file_request: CreateFileRequest,
|
||||
internal_usage_cache: InternalUsageCache,
|
||||
litellm_parent_otel_span: Span,
|
||||
target_model_names_list: List[str],
|
||||
) -> OpenAIFileObject:
|
||||
## GET THE FILE TYPE FROM THE CREATE FILE REQUEST
|
||||
file_data = extract_file_data(create_file_request["file"])
|
||||
|
||||
file_type = file_data["content_type"]
|
||||
|
||||
output_file_id = file_objects[0].id
|
||||
model_id = file_objects[0]._hidden_params.get("model_id")
|
||||
|
||||
unified_file_id = SpecialEnums.LITELLM_MANAGED_FILE_COMPLETE_STR.value.format(
|
||||
file_type,
|
||||
str(uuid.uuid4()),
|
||||
",".join(target_model_names_list),
|
||||
output_file_id,
|
||||
model_id,
|
||||
)
|
||||
|
||||
# Convert to URL-safe base64 and strip padding
|
||||
base64_unified_file_id = (
|
||||
base64.urlsafe_b64encode(unified_file_id.encode()).decode().rstrip("=")
|
||||
)
|
||||
|
||||
## CREATE RESPONSE OBJECT
|
||||
|
||||
response = OpenAIFileObject(
|
||||
id=base64_unified_file_id,
|
||||
object="file",
|
||||
purpose=create_file_request["purpose"],
|
||||
created_at=file_objects[0].created_at,
|
||||
bytes=file_objects[0].bytes,
|
||||
filename=file_objects[0].filename,
|
||||
status="uploaded",
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
def get_unified_generic_response_id(
|
||||
self, model_id: str, generic_response_id: str
|
||||
) -> str:
|
||||
unified_generic_response_id = (
|
||||
SpecialEnums.LITELLM_MANAGED_GENERIC_RESPONSE_COMPLETE_STR.value.format(
|
||||
model_id, generic_response_id
|
||||
)
|
||||
)
|
||||
return (
|
||||
base64.urlsafe_b64encode(unified_generic_response_id.encode())
|
||||
.decode()
|
||||
.rstrip("=")
|
||||
)
|
||||
|
||||
def get_unified_batch_id(self, batch_id: str, model_id: str) -> str:
|
||||
unified_batch_id = SpecialEnums.LITELLM_MANAGED_BATCH_COMPLETE_STR.value.format(
|
||||
model_id, batch_id
|
||||
)
|
||||
return base64.urlsafe_b64encode(unified_batch_id.encode()).decode().rstrip("=")
|
||||
|
||||
def get_unified_output_file_id(
|
||||
self, output_file_id: str, model_id: str, model_name: Optional[str]
|
||||
) -> str:
|
||||
unified_output_file_id = (
|
||||
SpecialEnums.LITELLM_MANAGED_FILE_COMPLETE_STR.value.format(
|
||||
"application/json",
|
||||
str(uuid.uuid4()),
|
||||
model_name or "",
|
||||
output_file_id,
|
||||
model_id,
|
||||
)
|
||||
)
|
||||
return (
|
||||
base64.urlsafe_b64encode(unified_output_file_id.encode())
|
||||
.decode()
|
||||
.rstrip("=")
|
||||
)
|
||||
|
||||
def get_model_id_from_unified_file_id(self, file_id: str) -> str:
|
||||
return file_id.split("llm_output_file_model_id,")[1].split(";")[0]
|
||||
|
||||
def get_output_file_id_from_unified_file_id(self, file_id: str) -> str:
|
||||
return file_id.split("llm_output_file_id,")[1].split(";")[0]
|
||||
|
||||
async def async_post_call_success_hook(
|
||||
self, data: Dict, user_api_key_dict: UserAPIKeyAuth, response: LLMResponseTypes
|
||||
) -> Any:
|
||||
if isinstance(response, LiteLLMBatch):
|
||||
## Check if unified_file_id is in the response
|
||||
unified_file_id = response._hidden_params.get(
|
||||
"unified_file_id"
|
||||
) # managed file id
|
||||
unified_batch_id = response._hidden_params.get(
|
||||
"unified_batch_id"
|
||||
) # managed batch id
|
||||
model_id = cast(Optional[str], response._hidden_params.get("model_id"))
|
||||
model_name = cast(Optional[str], response._hidden_params.get("model_name"))
|
||||
original_response_id = response.id
|
||||
|
||||
if (unified_batch_id or unified_file_id) and model_id:
|
||||
response.id = self.get_unified_batch_id(
|
||||
batch_id=response.id, model_id=model_id
|
||||
)
|
||||
|
||||
if (
|
||||
response.output_file_id and model_id
|
||||
): # return a file id with the model_id and output_file_id
|
||||
original_output_file_id = response.output_file_id
|
||||
response.output_file_id = self.get_unified_output_file_id(
|
||||
output_file_id=response.output_file_id,
|
||||
model_id=model_id,
|
||||
model_name=model_name,
|
||||
)
|
||||
await self.store_unified_file_id( # need to store otherwise any retrieve call will fail
|
||||
file_id=response.output_file_id,
|
||||
file_object=None,
|
||||
litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
|
||||
model_mappings={model_id: original_output_file_id},
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
)
|
||||
asyncio.create_task(
|
||||
self.store_unified_object_id(
|
||||
unified_object_id=response.id,
|
||||
file_object=response,
|
||||
litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
|
||||
model_object_id=original_response_id,
|
||||
file_purpose="batch",
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
)
|
||||
)
|
||||
elif isinstance(response, LiteLLMFineTuningJob):
|
||||
## Check if unified_file_id is in the response
|
||||
unified_file_id = response._hidden_params.get(
|
||||
"unified_file_id"
|
||||
) # managed file id
|
||||
unified_finetuning_job_id = response._hidden_params.get(
|
||||
"unified_finetuning_job_id"
|
||||
) # managed finetuning job id
|
||||
model_id = cast(Optional[str], response._hidden_params.get("model_id"))
|
||||
model_name = cast(Optional[str], response._hidden_params.get("model_name"))
|
||||
original_response_id = response.id
|
||||
if (unified_file_id or unified_finetuning_job_id) and model_id:
|
||||
response.id = self.get_unified_generic_response_id(
|
||||
model_id=model_id, generic_response_id=response.id
|
||||
)
|
||||
asyncio.create_task(
|
||||
self.store_unified_object_id(
|
||||
unified_object_id=response.id,
|
||||
file_object=response,
|
||||
litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
|
||||
model_object_id=original_response_id,
|
||||
file_purpose="fine-tune",
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
)
|
||||
)
|
||||
elif isinstance(response, AsyncCursorPage):
|
||||
"""
|
||||
For listing files, filter for the ones created by the user
|
||||
"""
|
||||
## check if file object
|
||||
if hasattr(response, "data") and isinstance(response.data, list):
|
||||
if all(
|
||||
isinstance(file_object, FileObject) for file_object in response.data
|
||||
):
|
||||
## Get all file id's
|
||||
## Check which file id's were created by the user
|
||||
## Filter the response to only include the files created by the user
|
||||
## Return the filtered response
|
||||
file_ids = [
|
||||
file_object.id
|
||||
for file_object in cast(List[FileObject], response.data) # type: ignore
|
||||
]
|
||||
user_created_file_ids = await self.get_user_created_file_ids(
|
||||
user_api_key_dict, file_ids
|
||||
)
|
||||
## Filter the response to only include the files created by the user
|
||||
response.data = user_created_file_ids # type: ignore
|
||||
return response
|
||||
return response
|
||||
return response
|
||||
|
||||
async def afile_retrieve(
|
||||
self, file_id: str, litellm_parent_otel_span: Optional[Span]
|
||||
) -> OpenAIFileObject:
|
||||
stored_file_object = await self.get_unified_file_id(
|
||||
file_id, litellm_parent_otel_span
|
||||
)
|
||||
if stored_file_object:
|
||||
return stored_file_object.file_object
|
||||
else:
|
||||
raise Exception(f"LiteLLM Managed File object with id={file_id} not found")
|
||||
|
||||
async def afile_list(
|
||||
self,
|
||||
purpose: Optional[OpenAIFilesPurpose],
|
||||
litellm_parent_otel_span: Optional[Span],
|
||||
**data: Dict,
|
||||
) -> List[OpenAIFileObject]:
|
||||
"""Handled in files_endpoints.py"""
|
||||
return []
|
||||
|
||||
async def afile_delete(
|
||||
self,
|
||||
file_id: str,
|
||||
litellm_parent_otel_span: Optional[Span],
|
||||
llm_router: Router,
|
||||
**data: Dict,
|
||||
) -> OpenAIFileObject:
|
||||
file_id = convert_b64_uid_to_unified_uid(file_id)
|
||||
model_file_id_mapping = await self.get_model_file_id_mapping(
|
||||
[file_id], litellm_parent_otel_span
|
||||
)
|
||||
specific_model_file_id_mapping = model_file_id_mapping.get(file_id)
|
||||
if specific_model_file_id_mapping:
|
||||
for model_id, file_id in specific_model_file_id_mapping.items():
|
||||
await llm_router.afile_delete(model=model_id, file_id=file_id, **data) # type: ignore
|
||||
|
||||
stored_file_object = await self.delete_unified_file_id(
|
||||
file_id, litellm_parent_otel_span
|
||||
)
|
||||
if stored_file_object:
|
||||
return stored_file_object
|
||||
else:
|
||||
raise Exception(f"LiteLLM Managed File object with id={file_id} not found")
|
||||
|
||||
async def afile_content(
|
||||
self,
|
||||
file_id: str,
|
||||
litellm_parent_otel_span: Optional[Span],
|
||||
llm_router: Router,
|
||||
**data: Dict,
|
||||
) -> "HttpxBinaryResponseContent":
|
||||
"""
|
||||
Get the content of a file from first model that has it
|
||||
"""
|
||||
model_file_id_mapping = data.pop("model_file_id_mapping", None)
|
||||
model_file_id_mapping = (
|
||||
model_file_id_mapping
|
||||
or await self.get_model_file_id_mapping([file_id], litellm_parent_otel_span)
|
||||
)
|
||||
specific_model_file_id_mapping = model_file_id_mapping.get(file_id)
|
||||
|
||||
if specific_model_file_id_mapping:
|
||||
exception_dict = {}
|
||||
for model_id, file_id in specific_model_file_id_mapping.items():
|
||||
try:
|
||||
return await llm_router.afile_content(model=model_id, file_id=file_id, **data) # type: ignore
|
||||
except Exception as e:
|
||||
exception_dict[model_id] = str(e)
|
||||
raise Exception(
|
||||
f"LiteLLM Managed File object with id={file_id} not found. Checked model id's: {specific_model_file_id_mapping.keys()}. Errors: {exception_dict}"
|
||||
)
|
||||
else:
|
||||
raise Exception(f"LiteLLM Managed File object with id={file_id} not found")
|
@@ -0,0 +1,8 @@
from fastapi import APIRouter

from .internal_user_endpoints import router as internal_user_endpoints_router

management_endpoints_router = APIRouter()
management_endpoints_router.include_router(internal_user_endpoints_router)

__all__ = ["management_endpoints_router"]
@@ -0,0 +1,67 @@
"""
Enterprise internal user management endpoints
"""

from fastapi import APIRouter, Depends, HTTPException

from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.management_endpoints.internal_user_endpoints import user_api_key_auth

router = APIRouter()


@router.get(
    "/user/available_users",
    tags=["Internal User management"],
    dependencies=[Depends(user_api_key_auth)],
)
async def available_enterprise_users(
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Return the number of users and teams on this instance, alongside the
    `max_users` / `max_teams` limits from the enterprise license (if set).
    """
    from litellm.proxy._types import CommonProxyErrors
    from litellm.proxy.proxy_server import (
        premium_user,
        premium_user_data,
        prisma_client,
    )

    if prisma_client is None:
        raise HTTPException(
            status_code=500,
            detail={"error": CommonProxyErrors.db_not_connected_error.value},
        )

    if premium_user is None:
        raise HTTPException(
            status_code=500, detail={"error": CommonProxyErrors.not_premium_user.value}
        )

    # Count number of rows in LiteLLM_UserTable and LiteLLM_TeamTable
    user_count = await prisma_client.db.litellm_usertable.count()
    team_count = await prisma_client.db.litellm_teamtable.count()

    if not premium_user_data or "max_users" not in premium_user_data:
        max_users = None
    else:
        max_users = premium_user_data.get("max_users")

    if premium_user_data and "max_teams" in premium_user_data:
        max_teams = premium_user_data.get("max_teams")
    else:
        max_teams = None

    return {
        "total_users": max_users,
        "total_teams": max_teams,
        "total_users_used": user_count,
        "total_teams_used": team_count,
        "total_teams_remaining": (max_teams - team_count if max_teams else None),
        "total_users_remaining": (max_users - user_count if max_users else None),
    }
@@ -0,0 +1,30 @@
from typing import Optional

from litellm.proxy._types import GenerateKeyRequest, LiteLLM_TeamTable


def add_team_member_key_duration(
    team_table: Optional[LiteLLM_TeamTable],
    data: GenerateKeyRequest,
) -> GenerateKeyRequest:
    if team_table is None:
        return data

    if data.user_id is None:  # only apply for team member keys, not service accounts
        return data

    if (
        team_table.metadata is not None
        and team_table.metadata.get("team_member_key_duration") is not None
    ):
        data.duration = team_table.metadata["team_member_key_duration"]

    return data


def apply_enterprise_key_management_params(
    data: GenerateKeyRequest,
    team_table: Optional[LiteLLM_TeamTable],
) -> GenerateKeyRequest:
    data = add_team_member_key_duration(team_table, data)
    return data
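In practice this lets a team admin set a default key lifetime once on the team, which every later team-member key inherits. A small illustration of the data flow; `team_with_metadata` stands in for a `LiteLLM_TeamTable` row and the values are made up.

```python
from litellm.proxy._types import GenerateKeyRequest

# `team_with_metadata` is assumed to be a LiteLLM_TeamTable row whose metadata is:
#   {"team_member_key_duration": "30d"}
data = GenerateKeyRequest(user_id="member-123")  # user_id set => team member key
data = apply_enterprise_key_management_params(data, team_table=team_with_metadata)
assert data.duration == "30d"
```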
@@ -0,0 +1,34 @@
import os
from typing import Optional

from litellm_enterprise.types.proxy.proxy_server import CustomAuthSettings

custom_auth_settings: Optional[CustomAuthSettings] = None


class EnterpriseProxyConfig:
    async def load_custom_auth_settings(
        self, general_settings: dict
    ) -> Optional[CustomAuthSettings]:
        custom_auth_settings = general_settings.get("custom_auth_settings", None)
        if custom_auth_settings is not None:
            custom_auth_settings = CustomAuthSettings(
                mode=custom_auth_settings.get("mode"),
            )
        return custom_auth_settings

    async def load_enterprise_config(self, general_settings: dict) -> None:
        global custom_auth_settings
        custom_auth_settings = await self.load_custom_auth_settings(general_settings)
        return None

    @staticmethod
    def get_custom_docs_description() -> Optional[str]:
        from litellm.proxy.proxy_server import premium_user

        docs_description: Optional[str] = None
        if premium_user:
            # Custom docs description is an enterprise-only feature, read from env
            docs_description = os.getenv("DOCS_DESCRIPTION")

        return docs_description
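A quick sketch of how the loader above might be driven; the `custom_auth_settings` block and the `mode` value are assumed for illustration:

```python
import asyncio

# Illustrative general_settings dict, as parsed from the proxy config file
general_settings = {"custom_auth_settings": {"mode": "on"}}  # "mode" value assumed

config = EnterpriseProxyConfig()
settings = asyncio.run(config.load_custom_auth_settings(general_settings))
print(settings)  # CustomAuthSettings built from the block above

# With no custom_auth_settings block, the loader returns None
print(asyncio.run(config.load_custom_auth_settings({})))  # None
```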
@@ -0,0 +1,11 @@
# LiteLLM Proxy Enterprise Features - Readme

## Overview

This directory contains the enterprise features used on the LiteLLM proxy.

## Format

Create a file for every group of endpoints (e.g. `key_management_endpoints.py`, `user_management_endpoints.py`, etc.).

If there is a broader semantic group of endpoints, create a folder for that group (e.g. `management_endpoints`, `auth_endpoints`, etc.).
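A hypothetical layout following this convention (file names other than those mentioned above are illustrative):

```
<this-directory>/
├── management_endpoints/
│   ├── key_management_endpoints.py
│   └── user_management_endpoints.py
├── auth_endpoints/
│   └── custom_auth_endpoints.py
└── README.md
```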
@@ -0,0 +1,35 @@
from typing import Optional, Union

from litellm.secret_managers.main import str_to_bool


def _should_block_robots():
    """
    Returns True if the robots.txt file should block web crawlers

    Controlled by

    ```yaml
    general_settings:
        block_robots: true
    ```
    """
    from litellm.proxy.proxy_server import (
        CommonProxyErrors,
        general_settings,
        premium_user,
    )

    _block_robots: Union[bool, str] = general_settings.get("block_robots", False)
    block_robots: Optional[bool] = None
    if isinstance(_block_robots, bool):
        block_robots = _block_robots
    elif isinstance(_block_robots, str):
        block_robots = str_to_bool(_block_robots)
    if block_robots is True:
        if premium_user is not True:
            raise ValueError(
                f"Blocking web crawlers is an enterprise feature. {CommonProxyErrors.not_premium_user.value}"
            )
        return True
    return False
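A minimal sketch of how this helper could gate a `robots.txt` route. The route itself is not part of this file, and the paths and response bodies are illustrative:

```python
from fastapi import APIRouter, Response

router = APIRouter()


@router.get("/robots.txt", response_class=Response)
async def robots_txt():
    # When block_robots is enabled (enterprise only), disallow all crawlers
    if _should_block_robots():
        body = "User-agent: *\nDisallow: /"
    else:
        body = "User-agent: *\nAllow: /"
    return Response(content=body, media_type="text/plain")
```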
@@ -0,0 +1,295 @@
"""
VECTOR STORE MANAGEMENT

All /vector_store management endpoints

/vector_store/new
/vector_store/delete
/vector_store/list
"""

import copy
from typing import List, Optional

from fastapi import APIRouter, Depends, HTTPException, Request, Response

import litellm
from litellm._logging import verbose_proxy_logger
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.types.vector_stores import (
    LiteLLM_ManagedVectorStore,
    LiteLLM_ManagedVectorStoreListResponse,
    VectorStoreDeleteRequest,
    VectorStoreInfoRequest,
    VectorStoreUpdateRequest,
)
from litellm.vector_stores.vector_store_registry import VectorStoreRegistry

router = APIRouter()

########################################################
# Management Endpoints
########################################################
@router.post(
    "/vector_store/new",
    tags=["vector store management"],
    dependencies=[Depends(user_api_key_auth)],
)
async def new_vector_store(
    vector_store: LiteLLM_ManagedVectorStore,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Create a new vector store.

    Parameters:
    - vector_store_id: str - Unique identifier for the vector store
    - custom_llm_provider: str - Provider of the vector store
    - vector_store_name: Optional[str] - Name of the vector store
    - vector_store_description: Optional[str] - Description of the vector store
    - vector_store_metadata: Optional[Dict] - Additional metadata for the vector store
    """
    from litellm.proxy.proxy_server import prisma_client
    from litellm.types.router import GenericLiteLLMParams

    if prisma_client is None:
        raise HTTPException(status_code=500, detail="Database not connected")

    try:
        # Check if vector store already exists
        existing_vector_store = (
            await prisma_client.db.litellm_managedvectorstorestable.find_unique(
                where={"vector_store_id": vector_store.get("vector_store_id")}
            )
        )
        if existing_vector_store is not None:
            raise HTTPException(
                status_code=400,
                detail=f"Vector store with ID {vector_store.get('vector_store_id')} already exists",
            )

        if vector_store.get("vector_store_metadata") is not None:
            vector_store["vector_store_metadata"] = safe_dumps(
                vector_store.get("vector_store_metadata")
            )

        # Safely handle JSON serialization of litellm_params
        litellm_params_json: Optional[str] = None
        _input_litellm_params: dict = vector_store.get("litellm_params", {}) or {}
        if _input_litellm_params:
            litellm_params_dict = GenericLiteLLMParams(
                **_input_litellm_params
            ).model_dump(exclude_none=True)
            litellm_params_json = safe_dumps(litellm_params_dict)
        vector_store.pop("litellm_params", None)

        _new_vector_store = (
            await prisma_client.db.litellm_managedvectorstorestable.create(
                data={
                    **vector_store,
                    "litellm_params": litellm_params_json,
                }
            )
        )

        new_vector_store: LiteLLM_ManagedVectorStore = LiteLLM_ManagedVectorStore(
            **_new_vector_store.model_dump()
        )

        # Add vector store to the in-memory registry
        if litellm.vector_store_registry is not None:
            litellm.vector_store_registry.add_vector_store_to_registry(
                vector_store=new_vector_store
            )

        return {
            "status": "success",
            "message": f"Vector store {vector_store.get('vector_store_id')} created successfully",
            "vector_store": new_vector_store,
        }
    except HTTPException:
        raise
    except Exception as e:
        verbose_proxy_logger.exception(f"Error creating vector store: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


@router.get(
    "/vector_store/list",
    tags=["vector store management"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=LiteLLM_ManagedVectorStoreListResponse,
)
async def list_vector_stores(
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
    page: int = 1,
    page_size: int = 100,
):
    """
    List all available vector stores with optional filtering and pagination.
    Combines both in-memory vector stores and those stored in the database.

    Parameters:
    - page: int - Page number for pagination (default: 1)
    - page_size: int - Number of items per page (default: 100)
    """
    from litellm.proxy.proxy_server import prisma_client

    seen_vector_store_ids = set()

    try:
        # Get in-memory vector stores
        in_memory_vector_stores: List[LiteLLM_ManagedVectorStore] = []
        if litellm.vector_store_registry is not None:
            in_memory_vector_stores = copy.deepcopy(
                litellm.vector_store_registry.vector_stores
            )

        # Get vector stores from database
        vector_stores_from_db = await VectorStoreRegistry._get_vector_stores_from_db(
            prisma_client=prisma_client
        )

        # Combine in-memory and database vector stores, de-duplicating by ID
        combined_vector_stores: List[LiteLLM_ManagedVectorStore] = []
        for vector_store in in_memory_vector_stores + vector_stores_from_db:
            vector_store_id = vector_store.get("vector_store_id", None)
            if vector_store_id not in seen_vector_store_ids:
                combined_vector_stores.append(vector_store)
                seen_vector_store_ids.add(vector_store_id)

        total_count = len(combined_vector_stores)
        total_pages = (total_count + page_size - 1) // page_size

        # Format response using LiteLLM_ManagedVectorStoreListResponse
        response = LiteLLM_ManagedVectorStoreListResponse(
            object="list",
            data=combined_vector_stores,
            total_count=total_count,
            current_page=page,
            total_pages=total_pages,
        )

        return response
    except Exception as e:
        verbose_proxy_logger.exception(f"Error listing vector stores: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


@router.post(
    "/vector_store/delete",
    tags=["vector store management"],
    dependencies=[Depends(user_api_key_auth)],
)
async def delete_vector_store(
    data: VectorStoreDeleteRequest,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Delete a vector store.

    Parameters:
    - vector_store_id: str - ID of the vector store to delete
    """
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(status_code=500, detail="Database not connected")

    try:
        # Check if vector store exists
        existing_vector_store = (
            await prisma_client.db.litellm_managedvectorstorestable.find_unique(
                where={"vector_store_id": data.vector_store_id}
            )
        )
        if existing_vector_store is None:
            raise HTTPException(
                status_code=404,
                detail=f"Vector store with ID {data.vector_store_id} not found",
            )

        # Delete vector store from the database
        await prisma_client.db.litellm_managedvectorstorestable.delete(
            where={"vector_store_id": data.vector_store_id}
        )

        # Delete vector store from the in-memory registry
        if litellm.vector_store_registry is not None:
            litellm.vector_store_registry.delete_vector_store_from_registry(
                vector_store_id=data.vector_store_id
            )

        return {"message": f"Vector store {data.vector_store_id} deleted successfully"}
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@router.post(
    "/vector_store/info",
    tags=["vector store management"],
    dependencies=[Depends(user_api_key_auth)],
)
async def get_vector_store_info(
    data: VectorStoreInfoRequest,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """Return a single vector store's details"""
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(status_code=500, detail="Database not connected")

    try:
        vector_store = (
            await prisma_client.db.litellm_managedvectorstorestable.find_unique(
                where={"vector_store_id": data.vector_store_id}
            )
        )
        if vector_store is None:
            raise HTTPException(
                status_code=404,
                detail=f"Vector store with ID {data.vector_store_id} not found",
            )

        vector_store_dict = vector_store.model_dump()
        return {"vector_store": vector_store_dict}
    except HTTPException:
        raise
    except Exception as e:
        verbose_proxy_logger.exception(f"Error getting vector store info: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


@router.post(
    "/vector_store/update",
    tags=["vector store management"],
    dependencies=[Depends(user_api_key_auth)],
)
async def update_vector_store(
    data: VectorStoreUpdateRequest,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """Update vector store details"""
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(status_code=500, detail="Database not connected")

    try:
        update_data = data.model_dump(exclude_unset=True)
        vector_store_id = update_data.pop("vector_store_id")
        if update_data.get("vector_store_metadata") is not None:
            update_data["vector_store_metadata"] = safe_dumps(
                update_data["vector_store_metadata"]
            )

        updated = await prisma_client.db.litellm_managedvectorstorestable.update(
            where={"vector_store_id": vector_store_id},
            data=update_data,
        )

        updated_vs = LiteLLM_ManagedVectorStore(**updated.model_dump())

        # Keep the in-memory registry in sync with the database
        if litellm.vector_store_registry is not None:
            litellm.vector_store_registry.update_vector_store_in_registry(
                vector_store_id=vector_store_id,
                updated_data=updated_vs,
            )

        return {"vector_store": updated_vs}
    except Exception as e:
        verbose_proxy_logger.exception(f"Error updating vector store: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
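For orientation, a client-side sketch of the create → list → info → delete lifecycle against the endpoints above. The base URL, API key, and provider values are placeholders, and the printed shapes are illustrative:

```python
import requests  # any HTTP client works; requests is used for brevity

BASE = "http://localhost:4000"                 # hypothetical proxy URL
HEADERS = {"Authorization": "Bearer sk-1234"}  # hypothetical admin key

# 1. Create a vector store (parameters per the /vector_store/new docstring)
created = requests.post(
    f"{BASE}/vector_store/new",
    headers=HEADERS,
    json={
        "vector_store_id": "vs-demo-001",
        "custom_llm_provider": "bedrock",      # provider value is illustrative
        "vector_store_name": "demo-store",
        "vector_store_description": "Example store",
        "vector_store_metadata": {"team": "search"},
    },
).json()
print(created["message"])  # "Vector store vs-demo-001 created successfully"

# 2. List stores (in-memory registry + database, de-duplicated by ID)
listing = requests.get(
    f"{BASE}/vector_store/list",
    headers=HEADERS,
    params={"page": 1, "page_size": 100},
).json()
print(listing["total_count"], [vs["vector_store_id"] for vs in listing["data"]])

# 3. Fetch details for one store
info = requests.post(
    f"{BASE}/vector_store/info",
    headers=HEADERS,
    json={"vector_store_id": "vs-demo-001"},
).json()
print(info["vector_store"]["vector_store_name"])

# 4. Delete it again
deleted = requests.post(
    f"{BASE}/vector_store/delete",
    headers=HEADERS,
    json={"vector_store_id": "vs-demo-001"},
).json()
print(deleted["message"])  # "Vector store vs-demo-001 deleted successfully"
```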