Files
Homelab/Development/litellm/enterprise/litellm_enterprise/enterprise_callbacks/callback_controls.py

92 lines
4.6 KiB
Python

from typing import Iterable, List, Optional, Union

import litellm
from litellm._logging import verbose_logger
from litellm.constants import X_LITELLM_DISABLE_CALLBACKS
from litellm.integrations.custom_logger import CustomLogger
from litellm.litellm_core_utils.llm_request_utils import (
    get_proxy_server_request_headers,
)
from litellm.proxy._types import CommonProxyErrors
from litellm.types.utils import StandardCallbackDynamicParams
class EnterpriseCallbackControls:
    """
    Enterprise-only controls for disabling logging callbacks per request.

    A request can disable callbacks in two ways:
    - the ``X_LITELLM_DISABLE_CALLBACKS`` request header, as comma-separated names;
    - ``litellm_disabled_callbacks`` in the request's standard callback dynamic
      params.

    When both are present, the header takes precedence.
    """

    @staticmethod
    def is_callback_disabled_dynamically(
        callback: litellm.CALLBACK_TYPES,
        litellm_params: dict,
        standard_callback_dynamic_params: StandardCallbackDynamicParams,
    ) -> bool:
        """
        Check if a callback is disabled via the x-litellm-disable-callbacks header
        or via `litellm_disabled_callbacks` in standard_callback_dynamic_params.

        Args:
            callback: The callback to check.
                - String names are matched case-insensitively against the disabled
                  list.
                - CustomLogger instances are first resolved to their registered
                  string name via CustomLoggerRegistry, then matched the same way.
                - Any other callback type (e.g. a plain callable) is never
                  reported as disabled.
            litellm_params: Parameters containing proxy server request info. The
                request headers are read from here.
            standard_callback_dynamic_params: Per-request callback params. These
                may carry ``litellm_disabled_callbacks``.

        Returns:
            bool: True if the callback should be disabled, False otherwise.
            Disabling only takes effect for premium users. Any exception raised
            while checking is logged at debug level and treated as "not disabled".
        """
        # Imported lazily, presumably to avoid an import cycle at module load time.
        from litellm.litellm_core_utils.custom_logger_registry import (
            CustomLoggerRegistry,
        )

        try:
            disabled_callbacks = EnterpriseCallbackControls.get_disabled_callbacks(
                litellm_params, standard_callback_dynamic_params
            )
            verbose_logger.debug(f"Dynamically disabled callbacks from {X_LITELLM_DISABLE_CALLBACKS}: {disabled_callbacks}")
            verbose_logger.debug(f"Checking if {callback} is disabled via headers. Disable callbacks from headers: {disabled_callbacks}")

            # Nothing to disable (absent or empty): skip the premium check, so
            # non-premium users are not warned about a feature they did not use.
            if not disabled_callbacks:
                return False

            #########################################################
            # premium user check
            #########################################################
            if not EnterpriseCallbackControls._premium_user_check():
                return False
            #########################################################

            # Resolve the callback to a comparable string name, if possible.
            callback_str: Optional[str] = None
            if isinstance(callback, str):
                callback_str = callback
            elif isinstance(callback, CustomLogger):
                callback_str = CustomLoggerRegistry.get_callback_str_from_class_type(
                    callback.__class__
                )

            # disabled_callbacks is already lower-cased by get_disabled_callbacks.
            if callback_str is not None and callback_str.lower() in disabled_callbacks:
                verbose_logger.debug(f"Not logging to {callback_str} because it is disabled via {X_LITELLM_DISABLE_CALLBACKS}")
                return True
            return False
        except Exception as e:
            # Best effort: a failure here must never block logging.
            verbose_logger.debug(
                f"Error checking disabled callbacks header: {str(e)}"
            )
            return False

    @staticmethod
    def get_disabled_callbacks(
        litellm_params: dict,
        standard_callback_dynamic_params: StandardCallbackDynamicParams,
    ) -> Optional[List[str]]:
        """
        Get the callbacks disabled for the current request.

        The ``X_LITELLM_DISABLE_CALLBACKS`` request header is checked first. The
        request-body param ``litellm_disabled_callbacks`` is only consulted when
        the header is absent.

        Names from both sources are normalized identically: stripped, lower-cased,
        de-duplicated, with empty entries dropped. This makes matching
        case-insensitive whichever source is used.

        Args:
            litellm_params: Parameters containing proxy server request info.
            standard_callback_dynamic_params: Per-request callback params.

        Returns:
            The normalized list of disabled callback names, in no particular
            order. Returns None if neither source specifies a value.
        """
        #########################################################
        # check if disabled via headers
        #########################################################
        request_headers = get_proxy_server_request_headers(litellm_params)
        header_value = request_headers.get(X_LITELLM_DISABLE_CALLBACKS, None)
        if header_value is not None:
            return EnterpriseCallbackControls._normalize_callback_names(header_value)

        #########################################################
        # check if disabled via request body
        #########################################################
        body_value = standard_callback_dynamic_params.get(
            "litellm_disabled_callbacks", None
        )
        if body_value is not None:
            return EnterpriseCallbackControls._normalize_callback_names(body_value)
        return None

    @staticmethod
    def _normalize_callback_names(names: Union[str, Iterable[str]]) -> List[str]:
        """
        Normalize disabled-callback names into a de-duplicated, lower-cased list.

        Accepts either:
        - a comma-separated string (the header format), or
        - an iterable of names (the request-body format).

        A string is split on commas rather than iterated character by character;
        iterating it directly would turn later membership tests into substring
        matches. Entries are stripped and lower-cased, and empty entries (e.g.
        from a trailing comma) are dropped.
        """
        if isinstance(names, str):
            names = names.split(",")
        normalized = {str(name).strip().lower() for name in names}
        normalized.discard("")
        return list(normalized)

    @staticmethod
    def _premium_user_check() -> bool:
        """
        Return True if ``premium_user`` from ``litellm.proxy.proxy_server`` is
        truthy. Otherwise log a warning that this is an enterprise feature and
        return False.
        """
        # Lazy import; NOTE(review): presumably avoids loading the proxy server
        # module (or an import cycle) until this check is actually needed.
        from litellm.proxy.proxy_server import premium_user

        if premium_user:
            return True
        verbose_logger.warning(f"Disabling callbacks using request headers is an enterprise feature. {CommonProxyErrors.not_premium_user.value}")
        return False