Added LiteLLM to the stack

2025-08-18 09:40:50 +00:00
parent 0648c1968c
commit d220b04e32
2682 changed files with 533609 additions and 1 deletion

@@ -0,0 +1,167 @@
"""
AUDIT LOGGING
All /audit logging endpoints, implemented as CRUD-style endpoints.
GET - /audit/{id} - Get audit log by id
GET - /audit - Get all audit logs
"""
from typing import Any, Dict, Optional
#### AUDIT LOGGING ####
from fastapi import APIRouter, Depends, HTTPException, Query
from litellm_enterprise.types.proxy.audit_logging_endpoints import (
AuditLogResponse,
PaginatedAuditLogResponse,
)
from litellm.proxy._types import CommonProxyErrors, UserAPIKeyAuth
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
router = APIRouter()
@router.get(
"/audit",
tags=["Audit Logging"],
dependencies=[Depends(user_api_key_auth)],
response_model=PaginatedAuditLogResponse,
)
async def get_audit_logs(
page: int = Query(1, ge=1),
page_size: int = Query(10, ge=1, le=100),
# Filter parameters
changed_by: Optional[str] = Query(
None, description="Filter by user or system that performed the action"
),
changed_by_api_key: Optional[str] = Query(
None, description="Filter by API key hash that performed the action"
),
action: Optional[str] = Query(
None, description="Filter by action type (create, update, delete)"
),
table_name: Optional[str] = Query(
None, description="Filter by table name that was modified"
),
object_id: Optional[str] = Query(
None, description="Filter by ID of the object that was modified"
),
start_date: Optional[str] = Query(None, description="Filter logs after this date"),
end_date: Optional[str] = Query(None, description="Filter logs before this date"),
# Sorting parameters
sort_by: Optional[str] = Query(
None,
description="Column to sort by (e.g. 'updated_at', 'action', 'table_name')",
),
sort_order: str = Query("desc", description="Sort order ('asc' or 'desc')"),
):
"""
Get all audit logs with filtering and pagination.
Returns a paginated response of audit logs matching the specified filters.
"""
from litellm.proxy.proxy_server import prisma_client
if prisma_client is None:
raise HTTPException(
status_code=500,
detail={"message": CommonProxyErrors.db_not_connected_error.value},
)
# Build filter conditions
where_conditions: Dict[str, Any] = {}
if changed_by:
where_conditions["changed_by"] = changed_by
if changed_by_api_key:
where_conditions["changed_by_api_key"] = changed_by_api_key
if action:
where_conditions["action"] = action
if table_name:
where_conditions["table_name"] = table_name
if object_id:
where_conditions["object_id"] = object_id
if start_date or end_date:
date_filter = {}
if start_date:
date_filter["gte"] = start_date
if end_date:
date_filter["lte"] = end_date
where_conditions["updated_at"] = date_filter
# Build sort conditions
order_by = {}
if sort_by and isinstance(sort_by, str):
order_by[sort_by] = sort_order
elif sort_order and isinstance(sort_order, str):
order_by["updated_at"] = sort_order # Default sort by updated_at
# Get paginated results
audit_logs = await prisma_client.db.litellm_auditlog.find_many(
where=where_conditions,
order=order_by,
skip=(page - 1) * page_size,
take=page_size,
)
# Get total count for pagination
total_count = await prisma_client.db.litellm_auditlog.count(where=where_conditions)
total_pages = -(-total_count // page_size) # Ceiling division
# Return paginated response
return PaginatedAuditLogResponse(
audit_logs=[
AuditLogResponse(**audit_log.model_dump()) for audit_log in audit_logs
]
if audit_logs
else [],
total=total_count,
page=page,
page_size=page_size,
total_pages=total_pages,
)
@router.get(
"/audit/{id}",
tags=["Audit Logging"],
dependencies=[Depends(user_api_key_auth)],
response_model=AuditLogResponse,
responses={
404: {"description": "Audit log not found"},
500: {"description": "Database connection error"},
},
)
async def get_audit_log_by_id(
id: str, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth)
):
"""
Get detailed information about a specific audit log entry by its ID.
Args:
id (str): The unique identifier of the audit log entry
Returns:
AuditLogResponse: Detailed information about the audit log entry
Raises:
HTTPException: If the audit log is not found or if there's a database connection error
"""
from litellm.proxy.proxy_server import prisma_client
if prisma_client is None:
raise HTTPException(
status_code=500,
detail={"message": CommonProxyErrors.db_not_connected_error.value},
)
# Get the audit log by ID
audit_log = await prisma_client.db.litellm_auditlog.find_unique(where={"id": id})
if audit_log is None:
raise HTTPException(
status_code=404, detail={"message": f"Audit log with ID {id} not found"}
)
# Convert to response model
return AuditLogResponse(**audit_log.model_dump())
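# --- Illustrative usage sketch (not part of this commit) ---
# A minimal client-side example of calling the audit endpoints above, assuming
# the proxy is reachable at http://localhost:4000 and LITELLM_API_KEY holds an
# admin key; both values are assumptions for illustration only.
if __name__ == "__main__":
    import os

    import requests

    base_url = "http://localhost:4000"
    headers = {"Authorization": f"Bearer {os.environ['LITELLM_API_KEY']}"}

    # List the 10 most recent "create" actions
    page = requests.get(
        f"{base_url}/audit",
        headers=headers,
        params={"action": "create", "page": 1, "page_size": 10, "sort_order": "desc"},
    ).json()
    print(page["total"], page["total_pages"])

    # Fetch a single entry by id
    if page["audit_logs"]:
        entry = requests.get(
            f"{base_url}/audit/{page['audit_logs'][0]['id']}", headers=headers
        ).json()
        print(entry)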

@@ -0,0 +1,10 @@
"""
Enterprise Authentication Module for LiteLLM Proxy
This module contains enterprise-specific authentication functionality,
including custom SSO handlers and advanced authentication features.
"""
from .custom_sso_handler import EnterpriseCustomSSOHandler
__all__ = ["EnterpriseCustomSSOHandler"]

@@ -0,0 +1,86 @@
"""
Enterprise Custom SSO Handler for LiteLLM Proxy
This module contains enterprise-specific custom SSO authentication functionality
that allows users to implement their own SSO handling logic by providing custom
handlers that process incoming request headers and return OpenID objects.
Use this when you have an OAuth proxy in front of LiteLLM (where the OAuth proxy
has already authenticated the user) and you need to extract user information from
custom headers or other request attributes.
"""
from typing import TYPE_CHECKING, Dict, Optional, Union, cast
from fastapi import Request
from fastapi.responses import RedirectResponse
if TYPE_CHECKING:
from fastapi_sso.sso.base import OpenID
else:
from typing import Any as OpenID
from litellm.proxy.management_endpoints.types import CustomOpenID
class EnterpriseCustomSSOHandler:
"""
Enterprise Custom SSO Handler for LiteLLM Proxy
This class provides methods for handling custom SSO authentication flows
where users can implement their own authentication logic by processing
request headers and returning user information in OpenID format.
"""
@staticmethod
async def handle_custom_ui_sso_sign_in(
request: Request,
) -> RedirectResponse:
"""
Allow a user to execute their custom code to parse incoming request headers and return a OpenID object
Use this when you have an OAuth proxy in front of LiteLLM (where the OAuth proxy has already authenticated the user)
Args:
request: The FastAPI request object containing headers and other request data
Returns:
RedirectResponse: Redirect response that sends the user to the LiteLLM UI with authentication token
Raises:
ValueError: If custom_ui_sso_sign_in_handler is not configured
Example:
This method is typically called when a user has already been authenticated by an
external OAuth proxy and the proxy has added custom headers containing user information.
The custom handler extracts this information and converts it to an OpenID object.
"""
from fastapi_sso.sso.base import OpenID
from litellm.integrations.custom_sso_handler import CustomSSOLoginHandler
from litellm.proxy.proxy_server import (
CommonProxyErrors,
premium_user,
user_custom_ui_sso_sign_in_handler,
)
if premium_user is not True:
raise ValueError(CommonProxyErrors.not_premium_user.value)
if user_custom_ui_sso_sign_in_handler is None:
raise ValueError("custom_ui_sso_sign_in_handler is not configured. Please set it in general_settings.")
custom_sso_login_handler = cast(CustomSSOLoginHandler, user_custom_ui_sso_sign_in_handler)
openid_response: OpenID = await custom_sso_login_handler.handle_custom_ui_sso_sign_in(
request=request,
)
# Import here to avoid circular imports
from litellm.proxy.management_endpoints.ui_sso import SSOAuthenticationHandler
return await SSOAuthenticationHandler.get_redirect_response_from_openid(
result=openid_response,
request=request,
received_response=None,
generic_client_id=None,
ui_access_mode=None,
)
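# --- Illustrative sketch (not part of this commit) ---
# One possible custom handler the flow above could delegate to. It is duck-typed
# against the handle_custom_ui_sso_sign_in(request=...) call made above, and the
# "x-forwarded-email" header name is an assumption for illustration.
class ExampleHeaderSSOHandler:
    async def handle_custom_ui_sso_sign_in(self, request: Request) -> OpenID:
        from fastapi_sso.sso.base import OpenID as _OpenID

        email = request.headers.get("x-forwarded-email", "")
        return _OpenID(
            id=email,
            email=email,
            display_name=email.split("@")[0] if email else None,
            provider="custom",
        )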

@@ -0,0 +1,66 @@
import os
from fastapi import HTTPException, status
class EnterpriseRouteChecks:
@staticmethod
def is_llm_api_route_disabled() -> bool:
"""
Check if llm api route is disabled
"""
from litellm.proxy._types import CommonProxyErrors
from litellm.proxy.proxy_server import premium_user
from litellm.secret_managers.main import get_secret_bool
## Check if DISABLE_LLM_API_ENDPOINTS is set
if "DISABLE_LLM_API_ENDPOINTS" in os.environ:
if not premium_user:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"🚨🚨🚨 DISABLING LLM API ENDPOINTS is an Enterprise feature\n🚨 {CommonProxyErrors.not_premium_user.value}",
)
return get_secret_bool("DISABLE_LLM_API_ENDPOINTS") is True
@staticmethod
def is_management_routes_disabled() -> bool:
"""
Check if management route is disabled
"""
from litellm.proxy._types import CommonProxyErrors
from litellm.proxy.proxy_server import premium_user
from litellm.secret_managers.main import get_secret_bool
if "DISABLE_ADMIN_ENDPOINTS" in os.environ:
if not premium_user:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"🚨🚨🚨 DISABLING LLM API ENDPOINTS is an Enterprise feature\n🚨 {CommonProxyErrors.not_premium_user.value}",
)
return get_secret_bool("DISABLE_ADMIN_ENDPOINTS") is True
@staticmethod
def should_call_route(route: str):
"""
Check whether the requested route is a disabled management or LLM API route, and raise an exception if so
"""
from litellm.proxy.auth.route_checks import RouteChecks
if (
RouteChecks.is_management_route(route=route)
and EnterpriseRouteChecks.is_management_routes_disabled()
):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Management routes are disabled for this instance.",
)
elif (
RouteChecks.is_llm_api_route(route=route)
and EnterpriseRouteChecks.is_llm_api_route_disabled()
):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="LLM API routes are disabled for this instance.",
)
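# --- Illustrative sketch (not part of this commit) ---
# How the checks above might behave for a management route. "/key/generate"
# and the env var value are assumptions for illustration; in the proxy these
# checks run inside the auth path rather than being called directly.
if __name__ == "__main__":
    os.environ["DISABLE_ADMIN_ENDPOINTS"] = "true"
    try:
        EnterpriseRouteChecks.should_call_route(route="/key/generate")
    except HTTPException as e:
        # 403 on a premium instance with admin endpoints disabled;
        # 500 (not_premium_user) if the instance is not premium.
        print(e.status_code, e.detail)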

@@ -0,0 +1,35 @@
from typing import Any, Optional
from fastapi import Request
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import ProxyException, UserAPIKeyAuth
async def enterprise_custom_auth(
request: Request, api_key: str, user_custom_auth: Optional[Any]
) -> Optional[UserAPIKeyAuth]:
from litellm_enterprise.proxy.proxy_server import custom_auth_settings
if user_custom_auth is None:
return None
if custom_auth_settings is None:
return await user_custom_auth(request, api_key)
if custom_auth_settings["mode"] == "on":
return await user_custom_auth(request, api_key)
elif custom_auth_settings["mode"] == "off":
return None
elif custom_auth_settings["mode"] == "auto":
try:
return await user_custom_auth(request, api_key)
except ProxyException as e:
raise e
except Exception as e:
verbose_proxy_logger.debug(
f"Error in custom auth, checking litellm auth: {e}"
)
return None
else:
raise ValueError(f"Invalid mode: {custom_auth_settings['mode']}")

@@ -0,0 +1,188 @@
"""
Polls LiteLLM_ManagedObjectTable to check whether pending batch jobs are complete and, once they are, tracks their cost.
"""
import uuid
from datetime import datetime
from typing import TYPE_CHECKING, Optional, cast
from litellm._logging import verbose_proxy_logger
if TYPE_CHECKING:
from litellm.proxy.utils import PrismaClient, ProxyLogging
from litellm.router import Router
class CheckBatchCost:
def __init__(
self,
proxy_logging_obj: "ProxyLogging",
prisma_client: "PrismaClient",
llm_router: "Router",
):
from litellm.proxy.utils import PrismaClient, ProxyLogging
from litellm.router import Router
self.proxy_logging_obj: ProxyLogging = proxy_logging_obj
self.prisma_client: PrismaClient = prisma_client
self.llm_router: Router = llm_router
async def check_batch_cost(self):
"""
Poll pending batch jobs and track their cost once complete.
- get all status="validating" and file_purpose="batch" jobs
- for each job, check whether the batch is now complete
- if complete, compute and log the batch cost/usage, then mark the job "complete"
"""
from litellm_enterprise.proxy.hooks.managed_files import (
_PROXY_LiteLLMManagedFiles,
)
from litellm.batches.batch_utils import (
_get_file_content_as_dictionary,
calculate_batch_cost_and_usage,
)
from litellm.litellm_core_utils.get_llm_provider_logic import get_llm_provider
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
from litellm.proxy.openai_files_endpoints.common_utils import (
_is_base64_encoded_unified_file_id,
get_batch_id_from_unified_batch_id,
get_model_id_from_unified_batch_id,
)
jobs = await self.prisma_client.db.litellm_managedobjecttable.find_many(
where={
"status": "validating",
"file_purpose": "batch",
}
)
completed_jobs = []
for job in jobs:
# get the model from the job
unified_object_id = job.unified_object_id
decoded_unified_object_id = _is_base64_encoded_unified_file_id(
unified_object_id
)
if not decoded_unified_object_id:
verbose_proxy_logger.info(
f"Skipping job {unified_object_id} because it is not a valid unified object id"
)
continue
else:
unified_object_id = decoded_unified_object_id
model_id = get_model_id_from_unified_batch_id(unified_object_id)
batch_id = get_batch_id_from_unified_batch_id(unified_object_id)
if model_id is None:
verbose_proxy_logger.info(
f"Skipping job {unified_object_id} because it is not a valid model id"
)
continue
verbose_proxy_logger.info(
f"Querying model ID: {model_id} for cost and usage of batch ID: {batch_id}"
)
try:
response = await self.llm_router.aretrieve_batch(
model=model_id,
batch_id=batch_id,
litellm_metadata={
"user_api_key_user_id": job.created_by or "default-user-id",
"batch_ignore_default_logging": True,
},
)
except Exception as e:
verbose_proxy_logger.info(
f"Skipping job {unified_object_id} because of error querying model ID: {model_id} for cost and usage of batch ID: {batch_id}: {e}"
)
continue
## RETRIEVE THE BATCH JOB OUTPUT FILE
managed_files_obj = cast(
Optional[_PROXY_LiteLLMManagedFiles],
self.proxy_logging_obj.get_proxy_hook("managed_files"),
)
if (
response.status == "completed"
and response.output_file_id is not None
and managed_files_obj is not None
):
verbose_proxy_logger.info(
f"Batch ID: {batch_id} is complete, tracking cost and usage"
)
# track cost
model_file_id_mapping = {
response.output_file_id: {model_id: response.output_file_id}
}
_file_content = await managed_files_obj.afile_content(
file_id=response.output_file_id,
litellm_parent_otel_span=None,
llm_router=self.llm_router,
model_file_id_mapping=model_file_id_mapping,
)
file_content_as_dict = _get_file_content_as_dictionary(
_file_content.content
)
deployment_info = self.llm_router.get_deployment(model_id=model_id)
if deployment_info is None:
verbose_proxy_logger.info(
f"Skipping job {unified_object_id} because it is not a valid deployment info"
)
continue
custom_llm_provider = deployment_info.litellm_params.custom_llm_provider
litellm_model_name = deployment_info.litellm_params.model
_, llm_provider, _, _ = get_llm_provider(
model=litellm_model_name,
custom_llm_provider=custom_llm_provider,
)
batch_cost, batch_usage, batch_models = (
await calculate_batch_cost_and_usage(
file_content_dictionary=file_content_as_dict,
custom_llm_provider=llm_provider, # type: ignore
)
)
logging_obj = LiteLLMLogging(
model=batch_models[0],
messages=[{"role": "user", "content": "<retrieve_batch>"}],
stream=False,
call_type="aretrieve_batch",
start_time=datetime.now(),
litellm_call_id=str(uuid.uuid4()),
function_id=str(uuid.uuid4()),
)
logging_obj.update_environment_variables(
litellm_params={
"metadata": {
"user_api_key_user_id": job.created_by or "default-user-id",
}
},
optional_params={},
)
await logging_obj.async_success_handler(
result=response,
batch_cost=batch_cost,
batch_usage=batch_usage,
batch_models=batch_models,
)
# mark the job as complete
completed_jobs.append(job)
if len(completed_jobs) > 0:
# mark the jobs as complete
await self.prisma_client.db.litellm_managedobjecttable.update_many(
where={"id": {"in": [job.id for job in completed_jobs]}},
data={"status": "complete"},
)
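# --- Illustrative sketch (not part of this commit) ---
# One way the checker above might be scheduled; the 60-second interval is an
# assumption for illustration, and the proxy wires its background jobs differently.
async def poll_batch_costs_forever(checker: CheckBatchCost, interval_seconds: int = 60) -> None:
    import asyncio

    while True:
        try:
            await checker.check_batch_cost()
        except Exception as e:  # keep the loop alive on transient errors
            verbose_proxy_logger.debug(f"check_batch_cost failed: {e}")
        await asyncio.sleep(interval_seconds)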

@@ -0,0 +1,30 @@
from fastapi import APIRouter
from fastapi.responses import Response
from litellm_enterprise.enterprise_callbacks.send_emails.endpoints import (
router as email_events_router,
)
from .audit_logging_endpoints import router as audit_logging_router
from .guardrails.endpoints import router as guardrails_router
from .management_endpoints import management_endpoints_router
from .utils import _should_block_robots
from .vector_stores.endpoints import router as vector_stores_router
router = APIRouter()
router.include_router(vector_stores_router)
router.include_router(guardrails_router)
router.include_router(email_events_router)
router.include_router(audit_logging_router)
router.include_router(management_endpoints_router)
@router.get("/robots.txt")
async def get_robots():
"""
Block all web crawlers from indexing the proxy server endpoints
This is useful for ensuring that the API endpoints aren't indexed by search engines
"""
if _should_block_robots():
return Response(content="User-agent: *\nDisallow: /", media_type="text/plain")
else:
return Response(status_code=404)

@@ -0,0 +1,41 @@
"""
Enterprise Guardrail Routes on LiteLLM Proxy
To see all free guardrails see litellm/proxy/guardrails/*
Exposed Routes:
- /guardrails/apply_guardrail
"""
from typing import Optional
from fastapi import APIRouter, Depends
from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.guardrails.guardrail_endpoints import GUARDRAIL_REGISTRY
from litellm.types.guardrails import ApplyGuardrailRequest, ApplyGuardrailResponse
router = APIRouter(tags=["guardrails"], prefix="/guardrails")
@router.post("/apply_guardrail", response_model=ApplyGuardrailResponse)
async def apply_guardrail(
request: ApplyGuardrailRequest,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
Apply a guardrail (e.g. PII masking) to the given text. Requires the guardrail to be configured on the litellm proxy.
"""
active_guardrail: Optional[
CustomGuardrail
] = GUARDRAIL_REGISTRY.get_initialized_guardrail_callback(
guardrail_name=request.guardrail_name
)
if active_guardrail is None:
raise Exception(f"Guardrail {request.guardrail_name} not found")
return await active_guardrail.apply_guardrail(
text=request.text, language=request.language, entities=request.entities
)
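# --- Illustrative usage sketch (not part of this commit) ---
# Calling the endpoint above, assuming a guardrail named "presidio-pii" is
# configured on a proxy at http://localhost:4000; both names are assumptions.
if __name__ == "__main__":
    import os

    import requests

    resp = requests.post(
        "http://localhost:4000/guardrails/apply_guardrail",
        headers={"Authorization": f"Bearer {os.environ['LITELLM_API_KEY']}"},
        json={
            "guardrail_name": "presidio-pii",
            "text": "My phone number is 555-0100",
            "language": "en",
            "entities": ["PHONE_NUMBER"],
        },
    )
    print(resp.json())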

@@ -0,0 +1,830 @@
# What is this?
## This hook detects LiteLLM managed file ids in the request body and replaces them with model-specific file ids
import asyncio
import base64
import json
import uuid
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union, cast
from fastapi import HTTPException
from litellm import Router, verbose_logger
from litellm.caching.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger
from litellm.litellm_core_utils.prompt_templates.common_utils import extract_file_data
from litellm.llms.base_llm.files.transformation import BaseFileEndpoints
from litellm.proxy._types import (
CallTypes,
LiteLLM_ManagedFileTable,
LiteLLM_ManagedObjectTable,
UserAPIKeyAuth,
)
from litellm.proxy.openai_files_endpoints.common_utils import (
_is_base64_encoded_unified_file_id,
convert_b64_uid_to_unified_uid,
get_batch_id_from_unified_batch_id,
get_model_id_from_unified_batch_id,
)
from litellm.types.llms.openai import (
AllMessageValues,
AsyncCursorPage,
ChatCompletionFileObject,
CreateFileRequest,
FileObject,
OpenAIFileObject,
OpenAIFilesPurpose,
)
from litellm.types.utils import (
LiteLLMBatch,
LiteLLMFineTuningJob,
LLMResponseTypes,
SpecialEnums,
)
if TYPE_CHECKING:
from litellm.types.llms.openai import HttpxBinaryResponseContent
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
from litellm.proxy.utils import InternalUsageCache as _InternalUsageCache
from litellm.proxy.utils import PrismaClient as _PrismaClient
Span = Union[_Span, Any]
InternalUsageCache = _InternalUsageCache
PrismaClient = _PrismaClient
else:
Span = Any
InternalUsageCache = Any
PrismaClient = Any
class _PROXY_LiteLLMManagedFiles(CustomLogger, BaseFileEndpoints):
# Class variables or attributes
def __init__(
self, internal_usage_cache: InternalUsageCache, prisma_client: PrismaClient
):
self.internal_usage_cache = internal_usage_cache
self.prisma_client = prisma_client
async def store_unified_file_id(
self,
file_id: str,
file_object: Optional[OpenAIFileObject],
litellm_parent_otel_span: Optional[Span],
model_mappings: Dict[str, str],
user_api_key_dict: UserAPIKeyAuth,
) -> None:
verbose_logger.info(
f"Storing LiteLLM Managed File object with id={file_id} in cache"
)
if file_object is not None:
litellm_managed_file_object = LiteLLM_ManagedFileTable(
unified_file_id=file_id,
file_object=file_object,
model_mappings=model_mappings,
flat_model_file_ids=list(model_mappings.values()),
created_by=user_api_key_dict.user_id,
updated_by=user_api_key_dict.user_id,
)
await self.internal_usage_cache.async_set_cache(
key=file_id,
value=litellm_managed_file_object.model_dump(),
litellm_parent_otel_span=litellm_parent_otel_span,
)
## STORE MODEL MAPPINGS IN DB
db_data = {
"unified_file_id": file_id,
"model_mappings": json.dumps(model_mappings),
"flat_model_file_ids": list(model_mappings.values()),
"created_by": user_api_key_dict.user_id,
"updated_by": user_api_key_dict.user_id,
}
if file_object is not None:
db_data["file_object"] = file_object.model_dump_json()
result = await self.prisma_client.db.litellm_managedfiletable.create(
data=db_data
)
verbose_logger.debug(
f"LiteLLM Managed File object with id={file_id} stored in db: {result}"
)
async def store_unified_object_id(
self,
unified_object_id: str,
file_object: Union[LiteLLMBatch, LiteLLMFineTuningJob],
litellm_parent_otel_span: Optional[Span],
model_object_id: str,
file_purpose: Literal["batch", "fine-tune"],
user_api_key_dict: UserAPIKeyAuth,
) -> None:
verbose_logger.info(
f"Storing LiteLLM Managed {file_purpose} object with id={unified_object_id} in cache"
)
litellm_managed_object = LiteLLM_ManagedObjectTable(
unified_object_id=unified_object_id,
model_object_id=model_object_id,
file_purpose=file_purpose,
file_object=file_object,
)
await self.internal_usage_cache.async_set_cache(
key=unified_object_id,
value=litellm_managed_object.model_dump(),
litellm_parent_otel_span=litellm_parent_otel_span,
)
await self.prisma_client.db.litellm_managedobjecttable.upsert(
where={"unified_object_id": unified_object_id},
data={
"create": {
"unified_object_id": unified_object_id,
"file_object": file_object.model_dump_json(),
"model_object_id": model_object_id,
"file_purpose": file_purpose,
"created_by": user_api_key_dict.user_id,
"updated_by": user_api_key_dict.user_id,
"status": file_object.status,
},
"update": {}, # don't do anything if it already exists
}
)
async def get_unified_file_id(
self, file_id: str, litellm_parent_otel_span: Optional[Span] = None
) -> Optional[LiteLLM_ManagedFileTable]:
## CHECK CACHE
result = cast(
Optional[dict],
await self.internal_usage_cache.async_get_cache(
key=file_id,
litellm_parent_otel_span=litellm_parent_otel_span,
),
)
if result:
return LiteLLM_ManagedFileTable(**result)
## CHECK DB
db_object = await self.prisma_client.db.litellm_managedfiletable.find_first(
where={"unified_file_id": file_id}
)
if db_object:
return LiteLLM_ManagedFileTable(**db_object.model_dump())
return None
async def delete_unified_file_id(
self, file_id: str, litellm_parent_otel_span: Optional[Span] = None
) -> OpenAIFileObject:
## get old value
initial_value = await self.prisma_client.db.litellm_managedfiletable.find_first(
where={"unified_file_id": file_id}
)
if initial_value is None:
raise Exception(f"LiteLLM Managed File object with id={file_id} not found")
## delete old value
await self.internal_usage_cache.async_set_cache(
key=file_id,
value=None,
litellm_parent_otel_span=litellm_parent_otel_span,
)
await self.prisma_client.db.litellm_managedfiletable.delete(
where={"unified_file_id": file_id}
)
return initial_value.file_object
async def can_user_call_unified_file_id(
self, unified_file_id: str, user_api_key_dict: UserAPIKeyAuth
) -> bool:
## check if the user has access to the unified file id
user_id = user_api_key_dict.user_id
managed_file = await self.prisma_client.db.litellm_managedfiletable.find_first(
where={"unified_file_id": unified_file_id}
)
if managed_file:
return managed_file.created_by == user_id
return False
async def can_user_call_unified_object_id(
self, unified_object_id: str, user_api_key_dict: UserAPIKeyAuth
) -> bool:
## check if the user has access to the unified object id
user_id = user_api_key_dict.user_id
managed_object = (
await self.prisma_client.db.litellm_managedobjecttable.find_first(
where={"unified_object_id": unified_object_id}
)
)
if managed_object:
return managed_object.created_by == user_id
return False
async def get_user_created_file_ids(
self, user_api_key_dict: UserAPIKeyAuth, model_object_ids: List[str]
) -> List[OpenAIFileObject]:
"""
Get all file ids created by the user for a list of model object ids
Returns:
- List of OpenAIFileObject's
"""
file_ids = await self.prisma_client.db.litellm_managedfiletable.find_many(
where={
"created_by": user_api_key_dict.user_id,
"flat_model_file_ids": {"hasSome": model_object_ids},
}
)
return [OpenAIFileObject(**file_object.file_object) for file_object in file_ids]
async def check_managed_file_id_access(
self, data: Dict, user_api_key_dict: UserAPIKeyAuth
) -> bool:
retrieve_file_id = cast(Optional[str], data.get("file_id"))
potential_file_id = (
_is_base64_encoded_unified_file_id(retrieve_file_id)
if retrieve_file_id
else False
)
if potential_file_id and retrieve_file_id:
if await self.can_user_call_unified_file_id(
retrieve_file_id, user_api_key_dict
):
return True
else:
raise HTTPException(
status_code=403,
detail=f"User {user_api_key_dict.user_id} does not have access to the file {retrieve_file_id}",
)
return False
async def async_pre_call_hook(
self,
user_api_key_dict: UserAPIKeyAuth,
cache: DualCache,
data: Dict,
call_type: Literal[
"completion",
"text_completion",
"embeddings",
"image_generation",
"moderation",
"audio_transcription",
"pass_through_endpoint",
"rerank",
"acreate_batch",
"aretrieve_batch",
"acreate_file",
"afile_list",
"afile_delete",
"afile_content",
"acreate_fine_tuning_job",
"aretrieve_fine_tuning_job",
"alist_fine_tuning_jobs",
"acancel_fine_tuning_job",
"mcp_call",
],
) -> Union[Exception, str, Dict, None]:
"""
- Detect litellm_proxy/ file_id
- add dictionary of mappings of litellm_proxy/ file_id -> provider_file_id => {litellm_proxy/file_id: {"model_id": id, "file_id": provider_file_id}}
"""
### HANDLE FILE ACCESS ### - ensure user has access to the file
if (
call_type == CallTypes.afile_content.value
or call_type == CallTypes.afile_delete.value
):
await self.check_managed_file_id_access(data, user_api_key_dict)
### HANDLE TRANSFORMATIONS ###
if call_type == CallTypes.completion.value:
messages = data.get("messages")
if messages:
file_ids = self.get_file_ids_from_messages(messages)
if file_ids:
model_file_id_mapping = await self.get_model_file_id_mapping(
file_ids, user_api_key_dict.parent_otel_span
)
data["model_file_id_mapping"] = model_file_id_mapping
elif call_type == CallTypes.afile_content.value:
retrieve_file_id = cast(Optional[str], data.get("file_id"))
potential_file_id = (
_is_base64_encoded_unified_file_id(retrieve_file_id)
if retrieve_file_id
else False
)
if potential_file_id:
model_id = self.get_model_id_from_unified_file_id(potential_file_id)
if model_id:
data["model"] = model_id
data["file_id"] = self.get_output_file_id_from_unified_file_id(
potential_file_id
)
elif call_type == CallTypes.acreate_batch.value:
input_file_id = cast(Optional[str], data.get("input_file_id"))
if input_file_id:
model_file_id_mapping = await self.get_model_file_id_mapping(
[input_file_id], user_api_key_dict.parent_otel_span
)
data["model_file_id_mapping"] = model_file_id_mapping
elif (
call_type == CallTypes.aretrieve_batch.value
or call_type == CallTypes.acancel_fine_tuning_job.value
or call_type == CallTypes.aretrieve_fine_tuning_job.value
):
accessor_key: Optional[str] = None
retrieve_object_id: Optional[str] = None
if call_type == CallTypes.aretrieve_batch.value:
accessor_key = "batch_id"
elif (
call_type == CallTypes.acancel_fine_tuning_job.value
or call_type == CallTypes.aretrieve_fine_tuning_job.value
):
accessor_key = "fine_tuning_job_id"
if accessor_key:
retrieve_object_id = cast(Optional[str], data.get(accessor_key))
potential_llm_object_id = (
_is_base64_encoded_unified_file_id(retrieve_object_id)
if retrieve_object_id
else False
)
if potential_llm_object_id and retrieve_object_id:
## VALIDATE USER HAS ACCESS TO THE OBJECT ##
if not await self.can_user_call_unified_object_id(
retrieve_object_id, user_api_key_dict
):
raise HTTPException(
status_code=403,
detail=f"User {user_api_key_dict.user_id} does not have access to the object {retrieve_object_id}",
)
## for managed batch id - get the model id
potential_model_id = get_model_id_from_unified_batch_id(
potential_llm_object_id
)
if potential_model_id is None:
raise Exception(
f"LiteLLM Managed {accessor_key} with id={retrieve_object_id} is invalid - does not contain encoded model_id."
)
data["model"] = potential_model_id
data[accessor_key] = get_batch_id_from_unified_batch_id(
potential_llm_object_id
)
elif call_type == CallTypes.acreate_fine_tuning_job.value:
input_file_id = cast(Optional[str], data.get("training_file"))
if input_file_id:
model_file_id_mapping = await self.get_model_file_id_mapping(
[input_file_id], user_api_key_dict.parent_otel_span
)
data["model_file_id_mapping"] = model_file_id_mapping
return data
async def async_filter_deployments(
self,
model: str,
healthy_deployments: List,
messages: Optional[List[AllMessageValues]],
request_kwargs: Optional[Dict] = None,
parent_otel_span: Optional[Span] = None,
) -> List[Dict]:
if request_kwargs is None:
return healthy_deployments
input_file_id = cast(Optional[str], request_kwargs.get("input_file_id"))
model_file_id_mapping = cast(
Optional[Dict[str, Dict[str, str]]],
request_kwargs.get("model_file_id_mapping"),
)
allowed_model_ids = []
if input_file_id and model_file_id_mapping:
model_id_dict = model_file_id_mapping.get(input_file_id, {})
allowed_model_ids = list(model_id_dict.keys())
if len(allowed_model_ids) == 0:
return healthy_deployments
return [
deployment
for deployment in healthy_deployments
if deployment.get("model_info", {}).get("id") in allowed_model_ids
]
async def async_pre_call_deployment_hook(
self, kwargs: Dict[str, Any], call_type: Optional[CallTypes]
) -> Optional[dict]:
"""
Allow modifying the request just before it's sent to the deployment.
"""
accessor_key: Optional[str] = None
if call_type and call_type == CallTypes.acreate_batch:
accessor_key = "input_file_id"
elif call_type and call_type == CallTypes.acreate_fine_tuning_job:
accessor_key = "training_file"
else:
return kwargs
if accessor_key:
input_file_id = cast(Optional[str], kwargs.get(accessor_key))
model_file_id_mapping = cast(
Optional[Dict[str, Dict[str, str]]], kwargs.get("model_file_id_mapping")
)
model_id = cast(Optional[str], kwargs.get("model_info", {}).get("id", None))
mapped_file_id: Optional[str] = None
if input_file_id and model_file_id_mapping and model_id:
mapped_file_id = model_file_id_mapping.get(input_file_id, {}).get(
model_id, None
)
if mapped_file_id:
kwargs[accessor_key] = mapped_file_id
return kwargs
def get_file_ids_from_messages(self, messages: List[AllMessageValues]) -> List[str]:
"""
Gets file ids from messages
"""
file_ids = []
for message in messages:
if message.get("role") == "user":
content = message.get("content")
if content:
if isinstance(content, str):
continue
for c in content:
if c["type"] == "file":
file_object = cast(ChatCompletionFileObject, c)
file_object_file_field = file_object["file"]
file_id = file_object_file_field.get("file_id")
if file_id:
file_ids.append(file_id)
return file_ids
async def get_model_file_id_mapping(
self, file_ids: List[str], litellm_parent_otel_span: Span
) -> dict:
"""
Get model-specific file IDs for a list of proxy file IDs.
Returns a dictionary mapping litellm_proxy/ file_id -> model_id -> model_file_id
1. Get all the litellm_proxy/ file_ids from the messages
2. For each file_id, search for cache keys matching the pattern file_id:*
3. Return a dictionary of mappings of litellm_proxy/ file_id -> model_id -> model_file_id
Example:
{
"litellm_proxy/file_id": {
"model_id": "model_file_id"
}
}
"""
file_id_mapping: Dict[str, Dict[str, str]] = {}
litellm_managed_file_ids = []
for file_id in file_ids:
## CHECK IF FILE ID IS MANAGED BY LITELM
is_base64_unified_file_id = _is_base64_encoded_unified_file_id(file_id)
if is_base64_unified_file_id:
litellm_managed_file_ids.append(file_id)
if litellm_managed_file_ids:
# Get all cache keys matching the pattern file_id:*
for file_id in litellm_managed_file_ids:
# Search for any cache key starting with this file_id
unified_file_object = await self.get_unified_file_id(
file_id, litellm_parent_otel_span
)
if unified_file_object:
file_id_mapping[file_id] = unified_file_object.model_mappings
return file_id_mapping
async def create_file_for_each_model(
self,
llm_router: Optional[Router],
_create_file_request: CreateFileRequest,
target_model_names_list: List[str],
litellm_parent_otel_span: Span,
) -> List[OpenAIFileObject]:
if llm_router is None:
raise Exception("LLM Router not initialized. Ensure models added to proxy.")
responses = []
for model in target_model_names_list:
individual_response = await llm_router.acreate_file(
model=model, **_create_file_request
)
responses.append(individual_response)
return responses
async def acreate_file(
self,
create_file_request: CreateFileRequest,
llm_router: Router,
target_model_names_list: List[str],
litellm_parent_otel_span: Span,
user_api_key_dict: UserAPIKeyAuth,
) -> OpenAIFileObject:
responses = await self.create_file_for_each_model(
llm_router=llm_router,
_create_file_request=create_file_request,
target_model_names_list=target_model_names_list,
litellm_parent_otel_span=litellm_parent_otel_span,
)
response = await _PROXY_LiteLLMManagedFiles.return_unified_file_id(
file_objects=responses,
create_file_request=create_file_request,
internal_usage_cache=self.internal_usage_cache,
litellm_parent_otel_span=litellm_parent_otel_span,
target_model_names_list=target_model_names_list,
)
## STORE MODEL MAPPINGS IN DB
model_mappings: Dict[str, str] = {}
for file_object in responses:
model_file_id_mapping = file_object._hidden_params.get(
"model_file_id_mapping"
)
if model_file_id_mapping and isinstance(model_file_id_mapping, dict):
model_mappings.update(model_file_id_mapping)
await self.store_unified_file_id(
file_id=response.id,
file_object=response,
litellm_parent_otel_span=litellm_parent_otel_span,
model_mappings=model_mappings,
user_api_key_dict=user_api_key_dict,
)
return response
@staticmethod
async def return_unified_file_id(
file_objects: List[OpenAIFileObject],
create_file_request: CreateFileRequest,
internal_usage_cache: InternalUsageCache,
litellm_parent_otel_span: Span,
target_model_names_list: List[str],
) -> OpenAIFileObject:
## GET THE FILE TYPE FROM THE CREATE FILE REQUEST
file_data = extract_file_data(create_file_request["file"])
file_type = file_data["content_type"]
output_file_id = file_objects[0].id
model_id = file_objects[0]._hidden_params.get("model_id")
unified_file_id = SpecialEnums.LITELLM_MANAGED_FILE_COMPLETE_STR.value.format(
file_type,
str(uuid.uuid4()),
",".join(target_model_names_list),
output_file_id,
model_id,
)
# Convert to URL-safe base64 and strip padding
base64_unified_file_id = (
base64.urlsafe_b64encode(unified_file_id.encode()).decode().rstrip("=")
)
## CREATE RESPONSE OBJECT
response = OpenAIFileObject(
id=base64_unified_file_id,
object="file",
purpose=create_file_request["purpose"],
created_at=file_objects[0].created_at,
bytes=file_objects[0].bytes,
filename=file_objects[0].filename,
status="uploaded",
)
return response
def get_unified_generic_response_id(
self, model_id: str, generic_response_id: str
) -> str:
unified_generic_response_id = (
SpecialEnums.LITELLM_MANAGED_GENERIC_RESPONSE_COMPLETE_STR.value.format(
model_id, generic_response_id
)
)
return (
base64.urlsafe_b64encode(unified_generic_response_id.encode())
.decode()
.rstrip("=")
)
def get_unified_batch_id(self, batch_id: str, model_id: str) -> str:
unified_batch_id = SpecialEnums.LITELLM_MANAGED_BATCH_COMPLETE_STR.value.format(
model_id, batch_id
)
return base64.urlsafe_b64encode(unified_batch_id.encode()).decode().rstrip("=")
def get_unified_output_file_id(
self, output_file_id: str, model_id: str, model_name: Optional[str]
) -> str:
unified_output_file_id = (
SpecialEnums.LITELLM_MANAGED_FILE_COMPLETE_STR.value.format(
"application/json",
str(uuid.uuid4()),
model_name or "",
output_file_id,
model_id,
)
)
return (
base64.urlsafe_b64encode(unified_output_file_id.encode())
.decode()
.rstrip("=")
)
def get_model_id_from_unified_file_id(self, file_id: str) -> str:
return file_id.split("llm_output_file_model_id,")[1].split(";")[0]
def get_output_file_id_from_unified_file_id(self, file_id: str) -> str:
return file_id.split("llm_output_file_id,")[1].split(";")[0]
async def async_post_call_success_hook(
self, data: Dict, user_api_key_dict: UserAPIKeyAuth, response: LLMResponseTypes
) -> Any:
if isinstance(response, LiteLLMBatch):
## Check if unified_file_id is in the response
unified_file_id = response._hidden_params.get(
"unified_file_id"
) # managed file id
unified_batch_id = response._hidden_params.get(
"unified_batch_id"
) # managed batch id
model_id = cast(Optional[str], response._hidden_params.get("model_id"))
model_name = cast(Optional[str], response._hidden_params.get("model_name"))
original_response_id = response.id
if (unified_batch_id or unified_file_id) and model_id:
response.id = self.get_unified_batch_id(
batch_id=response.id, model_id=model_id
)
if (
response.output_file_id and model_id
): # return a file id with the model_id and output_file_id
original_output_file_id = response.output_file_id
response.output_file_id = self.get_unified_output_file_id(
output_file_id=response.output_file_id,
model_id=model_id,
model_name=model_name,
)
await self.store_unified_file_id( # need to store otherwise any retrieve call will fail
file_id=response.output_file_id,
file_object=None,
litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
model_mappings={model_id: original_output_file_id},
user_api_key_dict=user_api_key_dict,
)
asyncio.create_task(
self.store_unified_object_id(
unified_object_id=response.id,
file_object=response,
litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
model_object_id=original_response_id,
file_purpose="batch",
user_api_key_dict=user_api_key_dict,
)
)
elif isinstance(response, LiteLLMFineTuningJob):
## Check if unified_file_id is in the response
unified_file_id = response._hidden_params.get(
"unified_file_id"
) # managed file id
unified_finetuning_job_id = response._hidden_params.get(
"unified_finetuning_job_id"
) # managed finetuning job id
model_id = cast(Optional[str], response._hidden_params.get("model_id"))
model_name = cast(Optional[str], response._hidden_params.get("model_name"))
original_response_id = response.id
if (unified_file_id or unified_finetuning_job_id) and model_id:
response.id = self.get_unified_generic_response_id(
model_id=model_id, generic_response_id=response.id
)
asyncio.create_task(
self.store_unified_object_id(
unified_object_id=response.id,
file_object=response,
litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
model_object_id=original_response_id,
file_purpose="fine-tune",
user_api_key_dict=user_api_key_dict,
)
)
elif isinstance(response, AsyncCursorPage):
"""
For listing files, filter for the ones created by the user
"""
## check if file object
if hasattr(response, "data") and isinstance(response.data, list):
if all(
isinstance(file_object, FileObject) for file_object in response.data
):
## Get all file id's
## Check which file id's were created by the user
## Filter the response to only include the files created by the user
## Return the filtered response
file_ids = [
file_object.id
for file_object in cast(List[FileObject], response.data) # type: ignore
]
user_created_file_ids = await self.get_user_created_file_ids(
user_api_key_dict, file_ids
)
## Filter the response to only include the files created by the user
response.data = user_created_file_ids # type: ignore
return response
return response
return response
async def afile_retrieve(
self, file_id: str, litellm_parent_otel_span: Optional[Span]
) -> OpenAIFileObject:
stored_file_object = await self.get_unified_file_id(
file_id, litellm_parent_otel_span
)
if stored_file_object:
return stored_file_object.file_object
else:
raise Exception(f"LiteLLM Managed File object with id={file_id} not found")
async def afile_list(
self,
purpose: Optional[OpenAIFilesPurpose],
litellm_parent_otel_span: Optional[Span],
**data: Dict,
) -> List[OpenAIFileObject]:
"""Handled in files_endpoints.py"""
return []
async def afile_delete(
self,
file_id: str,
litellm_parent_otel_span: Optional[Span],
llm_router: Router,
**data: Dict,
) -> OpenAIFileObject:
file_id = convert_b64_uid_to_unified_uid(file_id)
model_file_id_mapping = await self.get_model_file_id_mapping(
[file_id], litellm_parent_otel_span
)
specific_model_file_id_mapping = model_file_id_mapping.get(file_id)
if specific_model_file_id_mapping:
for model_id, file_id in specific_model_file_id_mapping.items():
await llm_router.afile_delete(model=model_id, file_id=file_id, **data) # type: ignore
stored_file_object = await self.delete_unified_file_id(
file_id, litellm_parent_otel_span
)
if stored_file_object:
return stored_file_object
else:
raise Exception(f"LiteLLM Managed File object with id={file_id} not found")
async def afile_content(
self,
file_id: str,
litellm_parent_otel_span: Optional[Span],
llm_router: Router,
**data: Dict,
) -> "HttpxBinaryResponseContent":
"""
Get the content of a file from first model that has it
"""
model_file_id_mapping = data.pop("model_file_id_mapping", None)
model_file_id_mapping = (
model_file_id_mapping
or await self.get_model_file_id_mapping([file_id], litellm_parent_otel_span)
)
specific_model_file_id_mapping = model_file_id_mapping.get(file_id)
if specific_model_file_id_mapping:
exception_dict = {}
for model_id, file_id in specific_model_file_id_mapping.items():
try:
return await llm_router.afile_content(model=model_id, file_id=file_id, **data) # type: ignore
except Exception as e:
exception_dict[model_id] = str(e)
raise Exception(
f"LiteLLM Managed File object with id={file_id} not found. Checked model id's: {specific_model_file_id_mapping.keys()}. Errors: {exception_dict}"
)
else:
raise Exception(f"LiteLLM Managed File object with id={file_id} not found")

@@ -0,0 +1,8 @@
from fastapi import APIRouter
from .internal_user_endpoints import router as internal_user_endpoints_router
management_endpoints_router = APIRouter()
management_endpoints_router.include_router(internal_user_endpoints_router)
__all__ = ["management_endpoints_router"]

@@ -0,0 +1,67 @@
"""
Enterprise internal user management endpoints
"""
from fastapi import APIRouter, Depends, HTTPException
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.management_endpoints.internal_user_endpoints import user_api_key_auth
router = APIRouter()
@router.get(
"/user/available_users",
tags=["Internal User management"],
dependencies=[Depends(user_api_key_auth)],
)
async def available_enterprise_users(
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
Return how many users and teams exist in the database, along with the license's max_users / max_teams limits and how many of each remain.
"""
from litellm.proxy._types import CommonProxyErrors
from litellm.proxy.proxy_server import (
premium_user,
premium_user_data,
prisma_client,
)
if prisma_client is None:
raise HTTPException(
status_code=500,
detail={"error": CommonProxyErrors.db_not_connected_error.value},
)
if premium_user is None:
raise HTTPException(
status_code=500, detail={"error": CommonProxyErrors.not_premium_user.value}
)
# Count number of rows in LiteLLM_UserTable
user_count = await prisma_client.db.litellm_usertable.count()
team_count = await prisma_client.db.litellm_teamtable.count()
if not premium_user_data or "max_users" not in premium_user_data:
max_users = None
else:
max_users = premium_user_data.get("max_users")
if premium_user_data and "max_teams" in premium_user_data:
max_teams = premium_user_data.get("max_teams")
else:
max_teams = None
return {
"total_users": max_users,
"total_teams": max_teams,
"total_users_used": user_count,
"total_teams_used": team_count,
"total_teams_remaining": (max_teams - team_count if max_teams else None),
"total_users_remaining": (max_users - user_count if max_users else None),
}
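# --- Illustrative usage sketch (not part of this commit) ---
# Calling the endpoint above; the base URL and key are assumptions. The
# response mirrors the dict returned by available_enterprise_users, e.g.:
#   {"total_users": 100, "total_teams": 10, "total_users_used": 42,
#    "total_teams_used": 3, "total_teams_remaining": 7, "total_users_remaining": 58}
if __name__ == "__main__":
    import os

    import requests

    resp = requests.get(
        "http://localhost:4000/user/available_users",
        headers={"Authorization": f"Bearer {os.environ['LITELLM_API_KEY']}"},
    )
    print(resp.json())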

@@ -0,0 +1,30 @@
from typing import Optional
from litellm.proxy._types import GenerateKeyRequest, LiteLLM_TeamTable
def add_team_member_key_duration(
team_table: Optional[LiteLLM_TeamTable],
data: GenerateKeyRequest,
) -> GenerateKeyRequest:
if team_table is None:
return data
if data.user_id is None: # only apply for team member keys, not service accounts
return data
if (
team_table.metadata is not None
and team_table.metadata.get("team_member_key_duration") is not None
):
data.duration = team_table.metadata["team_member_key_duration"]
return data
def apply_enterprise_key_management_params(
data: GenerateKeyRequest,
team_table: Optional[LiteLLM_TeamTable],
) -> GenerateKeyRequest:
data = add_team_member_key_duration(team_table, data)
return data
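# --- Illustrative sketch (not part of this commit) ---
# How the helpers above rewrite a key request. The "30d" duration format and
# the team metadata value are assumptions for illustration.
if __name__ == "__main__":
    team = LiteLLM_TeamTable(
        team_id="team-1",
        metadata={"team_member_key_duration": "30d"},
    )
    request = GenerateKeyRequest(user_id="user-1")
    print(apply_enterprise_key_management_params(request, team).duration)  # -> "30d"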

@@ -0,0 +1,34 @@
import os
from typing import Optional
from litellm_enterprise.types.proxy.proxy_server import CustomAuthSettings
custom_auth_settings: Optional[CustomAuthSettings] = None
class EnterpriseProxyConfig:
async def load_custom_auth_settings(
self, general_settings: dict
) -> CustomAuthSettings:
custom_auth_settings = general_settings.get("custom_auth_settings", None)
if custom_auth_settings is not None:
custom_auth_settings = CustomAuthSettings(
mode=custom_auth_settings.get("mode"),
)
return custom_auth_settings
async def load_enterprise_config(self, general_settings: dict) -> None:
global custom_auth_settings
custom_auth_settings = await self.load_custom_auth_settings(general_settings)
return None
@staticmethod
def get_custom_docs_description() -> Optional[str]:
from litellm.proxy.proxy_server import premium_user
docs_description: Optional[str] = None
if premium_user:
# custom docs description (read from DOCS_DESCRIPTION) is only exposed for premium users
docs_description = os.getenv("DOCS_DESCRIPTION")
return docs_description
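# --- Illustrative sketch (not part of this commit) ---
# custom_auth_settings is expected under general_settings in the proxy config,
# e.g. (values are assumptions for illustration):
#
#   general_settings:
#     custom_auth_settings:
#       mode: auto   # "on" / "off" / "auto"
#
# Loading it by hand would look like:
if __name__ == "__main__":
    import asyncio

    loaded = asyncio.run(
        EnterpriseProxyConfig().load_custom_auth_settings(
            general_settings={"custom_auth_settings": {"mode": "auto"}}
        )
    )
    print(loaded)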

@@ -0,0 +1,11 @@
# LiteLLM Proxy Enterprise Features - Readme
## Overview
This directory contains enterprise features used on the LiteLLM proxy.
## Format
Create a file for every group of endpoints (e.g. `key_management_endpoints.py`, `user_management_endpoints.py`, etc.)
If there is a broader semantic group of endpoints, create a folder for that group (e.g. `management_endpoints`, `auth_endpoints`, etc.)

@@ -0,0 +1,35 @@
from typing import Optional, Union
from litellm.secret_managers.main import str_to_bool
def _should_block_robots():
"""
Returns True if the robots.txt file should block web crawlers
Controlled by
```yaml
general_settings:
block_robots: true
```
"""
from litellm.proxy.proxy_server import (
CommonProxyErrors,
general_settings,
premium_user,
)
_block_robots: Union[bool, str] = general_settings.get("block_robots", False)
block_robots: Optional[bool] = None
if isinstance(_block_robots, bool):
block_robots = _block_robots
elif isinstance(_block_robots, str):
block_robots = str_to_bool(_block_robots)
if block_robots is True:
if premium_user is not True:
raise ValueError(
f"Blocking web crawlers is an enterprise feature. {CommonProxyErrors.not_premium_user.value}"
)
return True
return False

@@ -0,0 +1,295 @@
"""
VECTOR STORE MANAGEMENT
All /vector_store management endpoints
/vector_store/new
/vector_store/list
/vector_store/delete
/vector_store/info
/vector_store/update
"""
import copy
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, Request, Response
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.types.vector_stores import (
LiteLLM_ManagedVectorStore,
LiteLLM_ManagedVectorStoreListResponse,
VectorStoreDeleteRequest,
VectorStoreInfoRequest,
VectorStoreUpdateRequest,
)
from litellm.vector_stores.vector_store_registry import VectorStoreRegistry
router = APIRouter()
########################################################
# Management Endpoints
########################################################
@router.post(
"/vector_store/new",
tags=["vector store management"],
dependencies=[Depends(user_api_key_auth)],
)
async def new_vector_store(
vector_store: LiteLLM_ManagedVectorStore,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
Create a new vector store.
Parameters:
- vector_store_id: str - Unique identifier for the vector store
- custom_llm_provider: str - Provider of the vector store
- vector_store_name: Optional[str] - Name of the vector store
- vector_store_description: Optional[str] - Description of the vector store
- vector_store_metadata: Optional[Dict] - Additional metadata for the vector store
"""
from litellm.proxy.proxy_server import prisma_client
from litellm.types.router import GenericLiteLLMParams
if prisma_client is None:
raise HTTPException(status_code=500, detail="Database not connected")
try:
# Check if vector store already exists
existing_vector_store = (
await prisma_client.db.litellm_managedvectorstorestable.find_unique(
where={"vector_store_id": vector_store.get("vector_store_id")}
)
)
if existing_vector_store is not None:
raise HTTPException(
status_code=400,
detail=f"Vector store with ID {vector_store.get('vector_store_id')} already exists",
)
if vector_store.get("vector_store_metadata") is not None:
vector_store["vector_store_metadata"] = safe_dumps(
vector_store.get("vector_store_metadata")
)
# Safely handle JSON serialization of litellm_params
litellm_params_json: Optional[str] = None
_input_litellm_params: dict = vector_store.get("litellm_params", {}) or {}
if _input_litellm_params is not None:
litellm_params_dict = GenericLiteLLMParams(**_input_litellm_params).model_dump(exclude_none=True)
litellm_params_json = safe_dumps(litellm_params_dict)
del vector_store["litellm_params"]
_new_vector_store = (
await prisma_client.db.litellm_managedvectorstorestable.create(
data={
**vector_store,
"litellm_params": litellm_params_json,
}
)
)
new_vector_store: LiteLLM_ManagedVectorStore = LiteLLM_ManagedVectorStore(
**_new_vector_store.model_dump()
)
# Add vector store to registry
if litellm.vector_store_registry is not None:
litellm.vector_store_registry.add_vector_store_to_registry(
vector_store=new_vector_store
)
return {
"status": "success",
"message": f"Vector store {vector_store.get('vector_store_id')} created successfully",
"vector_store": new_vector_store,
}
except Exception as e:
verbose_proxy_logger.exception(f"Error creating vector store: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
@router.get(
"/vector_store/list",
tags=["vector store management"],
dependencies=[Depends(user_api_key_auth)],
response_model=LiteLLM_ManagedVectorStoreListResponse,
)
async def list_vector_stores(
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
page: int = 1,
page_size: int = 100,
):
"""
List all available vector stores with pagination.
Combines both in-memory vector stores and those stored in the database.
Parameters:
- page: int - Page number for pagination (default: 1)
- page_size: int - Number of items per page (default: 100)
"""
from litellm.proxy.proxy_server import prisma_client
seen_vector_store_ids = set()
try:
# Get in-memory vector stores
in_memory_vector_stores: List[LiteLLM_ManagedVectorStore] = []
if litellm.vector_store_registry is not None:
in_memory_vector_stores = copy.deepcopy(
litellm.vector_store_registry.vector_stores
)
# Get vector stores from database
vector_stores_from_db = await VectorStoreRegistry._get_vector_stores_from_db(
prisma_client=prisma_client
)
# Combine in-memory and database vector stores
combined_vector_stores: List[LiteLLM_ManagedVectorStore] = []
for vector_store in in_memory_vector_stores + vector_stores_from_db:
vector_store_id = vector_store.get("vector_store_id", None)
if vector_store_id not in seen_vector_store_ids:
combined_vector_stores.append(vector_store)
seen_vector_store_ids.add(vector_store_id)
total_count = len(combined_vector_stores)
total_pages = (total_count + page_size - 1) // page_size
# Format response using LiteLLM_ManagedVectorStoreListResponse
response = LiteLLM_ManagedVectorStoreListResponse(
object="list",
data=combined_vector_stores,
total_count=total_count,
current_page=page,
total_pages=total_pages,
)
return response
except Exception as e:
verbose_proxy_logger.exception(f"Error listing vector stores: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
@router.post(
"/vector_store/delete",
tags=["vector store management"],
dependencies=[Depends(user_api_key_auth)],
)
async def delete_vector_store(
data: VectorStoreDeleteRequest,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
Delete a vector store.
Parameters:
- vector_store_id: str - ID of the vector store to delete
"""
from litellm.proxy.proxy_server import prisma_client
if prisma_client is None:
raise HTTPException(status_code=500, detail="Database not connected")
try:
# Check if vector store exists
existing_vector_store = (
await prisma_client.db.litellm_managedvectorstorestable.find_unique(
where={"vector_store_id": data.vector_store_id}
)
)
if existing_vector_store is None:
raise HTTPException(
status_code=404,
detail=f"Vector store with ID {data.vector_store_id} not found",
)
# Delete vector store
await prisma_client.db.litellm_managedvectorstorestable.delete(
where={"vector_store_id": data.vector_store_id}
)
# Delete vector store from registry
if litellm.vector_store_registry is not None:
litellm.vector_store_registry.delete_vector_store_from_registry(
vector_store_id=data.vector_store_id
)
return {"message": f"Vector store {data.vector_store_id} deleted successfully"}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.post(
"/vector_store/info",
tags=["vector store management"],
dependencies=[Depends(user_api_key_auth)],
)
async def get_vector_store_info(
data: VectorStoreInfoRequest,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""Return a single vector store's details"""
from litellm.proxy.proxy_server import prisma_client
if prisma_client is None:
raise HTTPException(status_code=500, detail="Database not connected")
try:
vector_store = await prisma_client.db.litellm_managedvectorstorestable.find_unique(
where={"vector_store_id": data.vector_store_id}
)
if vector_store is None:
raise HTTPException(
status_code=404,
detail=f"Vector store with ID {data.vector_store_id} not found",
)
vector_store_dict = vector_store.model_dump()
return {"vector_store": vector_store_dict}
except Exception as e:
verbose_proxy_logger.exception(f"Error getting vector store info: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
@router.post(
"/vector_store/update",
tags=["vector store management"],
dependencies=[Depends(user_api_key_auth)],
)
async def update_vector_store(
data: VectorStoreUpdateRequest,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""Update vector store details"""
from litellm.proxy.proxy_server import prisma_client
if prisma_client is None:
raise HTTPException(status_code=500, detail="Database not connected")
try:
update_data = data.model_dump(exclude_unset=True)
vector_store_id = update_data.pop("vector_store_id")
if update_data.get("vector_store_metadata") is not None:
update_data["vector_store_metadata"] = safe_dumps(update_data["vector_store_metadata"])
updated = await prisma_client.db.litellm_managedvectorstorestable.update(
where={"vector_store_id": vector_store_id},
data=update_data,
)
updated_vs = LiteLLM_ManagedVectorStore(**updated.model_dump())
if litellm.vector_store_registry is not None:
litellm.vector_store_registry.update_vector_store_in_registry(
vector_store_id=vector_store_id,
updated_data=updated_vs,
)
return {"vector_store": updated_vs}
except Exception as e:
verbose_proxy_logger.exception(f"Error updating vector store: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
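# --- Illustrative usage sketch (not part of this commit) ---
# Creating and listing vector stores via the endpoints above. The base URL,
# key, and provider/id values are assumptions for illustration.
if __name__ == "__main__":
    import os

    import requests

    base_url = "http://localhost:4000"
    headers = {"Authorization": f"Bearer {os.environ['LITELLM_API_KEY']}"}

    created = requests.post(
        f"{base_url}/vector_store/new",
        headers=headers,
        json={
            "vector_store_id": "vs-docs-001",
            "custom_llm_provider": "bedrock",
            "vector_store_name": "internal docs",
        },
    ).json()
    print(created["message"])

    listing = requests.get(f"{base_url}/vector_store/list", headers=headers).json()
    print(listing["total_count"], [v["vector_store_id"] for v in listing["data"]])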