Added LiteLLM to the stack
This commit is contained in:
211
Development/litellm/tests/documentation_tests/test_api_docs.py
Normal file
211
Development/litellm/tests/documentation_tests/test_api_docs.py
Normal file
@@ -0,0 +1,211 @@
|
||||
import ast
|
||||
from typing import List, Dict, Set, Optional
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
import argparse
|
||||
import re
|
||||
import sys
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system path
|
||||
import litellm
|
||||
|
||||
|
||||
@dataclass
|
||||
class FunctionInfo:
|
||||
"""Store function information."""
|
||||
|
||||
name: str
|
||||
docstring: Optional[str]
|
||||
parameters: Set[str]
|
||||
file_path: str
|
||||
line_number: int
|
||||
|
||||
|
||||
class FastAPIDocVisitor(ast.NodeVisitor):
|
||||
"""AST visitor to find FastAPI endpoint functions."""
|
||||
|
||||
def __init__(self, target_functions: Set[str]):
|
||||
self.target_functions = target_functions
|
||||
self.functions: Dict[str, FunctionInfo] = {}
|
||||
self.current_file = ""
|
||||
|
||||
def visit_FunctionDef(self, node: ast.FunctionDef | ast.AsyncFunctionDef) -> None:
|
||||
"""Visit function definitions (both async and sync) and collect info if they match target functions."""
|
||||
if node.name in self.target_functions:
|
||||
# Extract docstring
|
||||
docstring = ast.get_docstring(node)
|
||||
|
||||
# Extract parameters
|
||||
parameters = set()
|
||||
for arg in node.args.args:
|
||||
if arg.annotation is not None:
|
||||
# Get the parameter type from annotation
|
||||
if isinstance(arg.annotation, ast.Name):
|
||||
parameters.add((arg.arg, arg.annotation.id))
|
||||
elif isinstance(arg.annotation, ast.Subscript):
|
||||
if isinstance(arg.annotation.value, ast.Name):
|
||||
parameters.add((arg.arg, arg.annotation.value.id))
|
||||
|
||||
self.functions[node.name] = FunctionInfo(
|
||||
name=node.name,
|
||||
docstring=docstring,
|
||||
parameters=parameters,
|
||||
file_path=self.current_file,
|
||||
line_number=node.lineno,
|
||||
)
|
||||
|
||||
# Also need to add this to handle async functions
|
||||
def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None:
|
||||
"""Handle async functions by delegating to the regular function visitor."""
|
||||
return self.visit_FunctionDef(node)
|
||||
|
||||
|
||||
def find_functions_in_file(
|
||||
file_path: str, target_functions: Set[str]
|
||||
) -> Dict[str, FunctionInfo]:
|
||||
"""Find target functions in a Python file using AST."""
|
||||
try:
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
|
||||
visitor = FastAPIDocVisitor(target_functions)
|
||||
visitor.current_file = file_path
|
||||
tree = ast.parse(content)
|
||||
visitor.visit(tree)
|
||||
return visitor.functions
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error parsing {file_path}: {str(e)}")
|
||||
return {}
|
||||
|
||||
|
||||
def extract_docstring_params(docstring: Optional[str]) -> Set[str]:
|
||||
"""Extract parameter names from docstring."""
|
||||
if not docstring:
|
||||
return set()
|
||||
|
||||
params = set()
|
||||
# Match parameters in format:
|
||||
# - parameter_name: description
|
||||
# or
|
||||
# parameter_name: description
|
||||
param_pattern = r"-?\s*([a-zA-Z_][a-zA-Z0-9_]*)\s*(?:\([^)]*\))?\s*:"
|
||||
|
||||
for match in re.finditer(param_pattern, docstring):
|
||||
params.add(match.group(1))
|
||||
|
||||
return params
|
||||
|
||||
|
||||
def analyze_function(func_info: FunctionInfo) -> Dict:
|
||||
"""Analyze function documentation and return validation results."""
|
||||
|
||||
docstring_params = extract_docstring_params(func_info.docstring)
|
||||
|
||||
print(f"func_info.parameters: {func_info.parameters}")
|
||||
pydantic_params = set()
|
||||
|
||||
for name, type_name in func_info.parameters:
|
||||
if type_name.endswith("Request") or type_name.endswith("Response"):
|
||||
pydantic_model = getattr(litellm.proxy._types, type_name, None)
|
||||
if pydantic_model is not None:
|
||||
for param in pydantic_model.model_fields.keys():
|
||||
pydantic_params.add(param)
|
||||
|
||||
print(f"pydantic_params: {pydantic_params}")
|
||||
|
||||
missing_params = pydantic_params - docstring_params
|
||||
|
||||
return {
|
||||
"function": func_info.name,
|
||||
"file_path": func_info.file_path,
|
||||
"line_number": func_info.line_number,
|
||||
"has_docstring": bool(func_info.docstring),
|
||||
"pydantic_params": list(pydantic_params),
|
||||
"documented_params": list(docstring_params),
|
||||
"missing_params": list(missing_params),
|
||||
"is_valid": len(missing_params) == 0,
|
||||
}
|
||||
|
||||
|
||||
def print_validation_results(results: Dict) -> None:
|
||||
"""Print validation results in a readable format."""
|
||||
print(f"\nChecking function: {results['function']}")
|
||||
print(f"File: {results['file_path']}:{results['line_number']}")
|
||||
print("-" * 50)
|
||||
|
||||
if not results["has_docstring"]:
|
||||
print("❌ No docstring found!")
|
||||
return
|
||||
|
||||
if not results["pydantic_params"]:
|
||||
print("ℹ️ No Pydantic input models found.")
|
||||
return
|
||||
|
||||
if results["is_valid"]:
|
||||
print("✅ All Pydantic parameters are documented!")
|
||||
else:
|
||||
print("❌ Missing documentation for parameters:")
|
||||
for param in sorted(results["missing_params"]):
|
||||
print(f" - {param}")
|
||||
|
||||
|
||||
def main():
|
||||
function_names = [
|
||||
"new_end_user",
|
||||
"end_user_info",
|
||||
"update_end_user",
|
||||
"delete_end_user",
|
||||
"generate_key_fn",
|
||||
"info_key_fn",
|
||||
"update_key_fn",
|
||||
"delete_key_fn",
|
||||
"new_user",
|
||||
"new_team",
|
||||
"team_info",
|
||||
"update_team",
|
||||
"delete_team",
|
||||
"new_organization",
|
||||
"update_organization",
|
||||
"delete_organization",
|
||||
"list_organization",
|
||||
"user_update",
|
||||
"new_budget",
|
||||
"info_budget",
|
||||
"update_budget",
|
||||
"delete_budget",
|
||||
"list_budget",
|
||||
]
|
||||
# directory = "../../litellm/proxy/management_endpoints" # LOCAL
|
||||
directory = "./litellm/proxy/management_endpoints"
|
||||
|
||||
# Convert function names to set for faster lookup
|
||||
target_functions = set(function_names)
|
||||
found_functions: Dict[str, FunctionInfo] = {}
|
||||
|
||||
# Walk through directory
|
||||
for root, _, files in os.walk(directory):
|
||||
for file in files:
|
||||
if file.endswith(".py"):
|
||||
file_path = os.path.join(root, file)
|
||||
found = find_functions_in_file(file_path, target_functions)
|
||||
found_functions.update(found)
|
||||
|
||||
# Analyze and output results
|
||||
for func_name in function_names:
|
||||
if func_name in found_functions:
|
||||
result = analyze_function(found_functions[func_name])
|
||||
if not result["is_valid"]:
|
||||
raise Exception(print_validation_results(result))
|
||||
# results.append(result)
|
||||
# print_validation_results(result)
|
||||
|
||||
# # Exit with error code if any validation failed
|
||||
# if any(not r["is_valid"] for r in results):
|
||||
# exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@@ -0,0 +1,162 @@
|
||||
import os
|
||||
import ast
|
||||
import sys
|
||||
from typing import List, Tuple, Optional
|
||||
|
||||
|
||||
def find_litellm_type_hints(directory: str) -> List[Tuple[str, int, str]]:
|
||||
"""
|
||||
Recursively search for Python files in the given directory
|
||||
and find type hints containing 'litellm.'.
|
||||
|
||||
Args:
|
||||
directory (str): The root directory to search for Python files
|
||||
|
||||
Returns:
|
||||
List of tuples containing (file_path, line_number, type_hint)
|
||||
"""
|
||||
litellm_type_hints = []
|
||||
|
||||
def is_litellm_type_hint(node):
|
||||
"""
|
||||
Recursively check if a type annotation contains 'litellm.'
|
||||
|
||||
Handles more complex type hints like:
|
||||
- Optional[litellm.Type]
|
||||
- Union[litellm.Type1, litellm.Type2]
|
||||
- Nested type hints
|
||||
"""
|
||||
try:
|
||||
# Convert node to string representation
|
||||
type_str = ast.unparse(node)
|
||||
|
||||
# Direct check for litellm in type string
|
||||
if "litellm." in type_str:
|
||||
return True
|
||||
|
||||
# Handle more complex type hints
|
||||
if isinstance(node, ast.Subscript):
|
||||
# Check Union or Optional types
|
||||
if isinstance(node.value, ast.Name) and node.value.id in [
|
||||
"Union",
|
||||
"Optional",
|
||||
]:
|
||||
# Check each element in the Union/Optional type
|
||||
if isinstance(node.slice, ast.Tuple):
|
||||
return any(is_litellm_type_hint(elt) for elt in node.slice.elts)
|
||||
else:
|
||||
return is_litellm_type_hint(node.slice)
|
||||
|
||||
# Recursive check for subscripted types
|
||||
return is_litellm_type_hint(node.value) or is_litellm_type_hint(
|
||||
node.slice
|
||||
)
|
||||
|
||||
# Recursive check for attribute types
|
||||
if isinstance(node, ast.Attribute):
|
||||
return "litellm." in ast.unparse(node)
|
||||
|
||||
# Recursive check for name types
|
||||
if isinstance(node, ast.Name):
|
||||
return "litellm" in node.id
|
||||
|
||||
return False
|
||||
except Exception:
|
||||
# Fallback to string checking if parsing fails
|
||||
try:
|
||||
return "litellm." in ast.unparse(node)
|
||||
except:
|
||||
return False
|
||||
|
||||
def scan_file(file_path: str):
|
||||
"""
|
||||
Scan a single Python file for LiteLLM type hints
|
||||
"""
|
||||
try:
|
||||
# Use utf-8-sig to handle files with BOM, ignore errors
|
||||
with open(file_path, "r", encoding="utf-8-sig", errors="ignore") as file:
|
||||
tree = ast.parse(file.read())
|
||||
|
||||
for node in ast.walk(tree):
|
||||
# Check type annotations in variable annotations
|
||||
if isinstance(node, ast.AnnAssign) and node.annotation:
|
||||
if is_litellm_type_hint(node.annotation):
|
||||
litellm_type_hints.append(
|
||||
(file_path, node.lineno, ast.unparse(node.annotation))
|
||||
)
|
||||
|
||||
# Check type hints in function arguments
|
||||
elif isinstance(node, ast.FunctionDef):
|
||||
for arg in node.args.args:
|
||||
if arg.annotation and is_litellm_type_hint(arg.annotation):
|
||||
litellm_type_hints.append(
|
||||
(file_path, arg.lineno, ast.unparse(arg.annotation))
|
||||
)
|
||||
|
||||
# Check return type annotation
|
||||
if node.returns and is_litellm_type_hint(node.returns):
|
||||
litellm_type_hints.append(
|
||||
(file_path, node.lineno, ast.unparse(node.returns))
|
||||
)
|
||||
except SyntaxError as e:
|
||||
print(f"Syntax error in {file_path}: {e}", file=sys.stderr)
|
||||
except Exception as e:
|
||||
print(f"Error processing {file_path}: {e}", file=sys.stderr)
|
||||
|
||||
# Recursively walk through directory
|
||||
for root, dirs, files in os.walk(directory):
|
||||
# Remove virtual environment and cache directories from search
|
||||
dirs[:] = [
|
||||
d
|
||||
for d in dirs
|
||||
if not any(
|
||||
venv in d
|
||||
for venv in [
|
||||
"venv",
|
||||
"env",
|
||||
"myenv",
|
||||
".venv",
|
||||
"__pycache__",
|
||||
".pytest_cache",
|
||||
]
|
||||
)
|
||||
]
|
||||
|
||||
for file in files:
|
||||
if file.endswith(".py"):
|
||||
full_path = os.path.join(root, file)
|
||||
# Skip files in virtual environment or cache directories
|
||||
if not any(
|
||||
venv in full_path
|
||||
for venv in [
|
||||
"venv",
|
||||
"env",
|
||||
"myenv",
|
||||
".venv",
|
||||
"__pycache__",
|
||||
".pytest_cache",
|
||||
]
|
||||
):
|
||||
scan_file(full_path)
|
||||
|
||||
return litellm_type_hints
|
||||
|
||||
|
||||
def main():
|
||||
# Get directory from command line argument or use current directory
|
||||
directory = "./litellm/"
|
||||
|
||||
# Find LiteLLM type hints
|
||||
results = find_litellm_type_hints(directory)
|
||||
|
||||
# Print results
|
||||
if results:
|
||||
print("LiteLLM Type Hints Found:")
|
||||
for file_path, line_num, type_hint in results:
|
||||
print(f"{file_path}:{line_num} - {type_hint}")
|
||||
else:
|
||||
print("No LiteLLM type hints found.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@@ -0,0 +1,96 @@
|
||||
import os
|
||||
import re
|
||||
|
||||
# Define the base directory for the litellm repository and documentation path
|
||||
repo_base = "./litellm" # Change this to your actual path
|
||||
|
||||
# Regular expressions to capture the keys used in os.getenv() and litellm.get_secret()
|
||||
getenv_pattern = re.compile(r'os\.getenv\(\s*[\'"]([^\'"]+)[\'"]\s*(?:,\s*[^)]*)?\)')
|
||||
get_secret_pattern = re.compile(
|
||||
r'litellm\.get_secret\(\s*[\'"]([^\'"]+)[\'"]\s*(?:,\s*[^)]*|,\s*default_value=[^)]*)?\)'
|
||||
)
|
||||
get_secret_str_pattern = re.compile(
|
||||
r'litellm\.get_secret_str\(\s*[\'"]([^\'"]+)[\'"]\s*(?:,\s*[^)]*|,\s*default_value=[^)]*)?\)'
|
||||
)
|
||||
|
||||
# Set to store unique keys from the code
|
||||
env_keys = set()
|
||||
|
||||
# Walk through all files in the litellm repo to find references of os.getenv() and litellm.get_secret()
|
||||
for root, dirs, files in os.walk(repo_base):
|
||||
for file in files:
|
||||
if file.endswith(".py"): # Only process Python files
|
||||
file_path = os.path.join(root, file)
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
|
||||
# Find all keys using os.getenv()
|
||||
getenv_matches = getenv_pattern.findall(content)
|
||||
env_keys.update(
|
||||
match for match in getenv_matches
|
||||
) # Extract only the key part
|
||||
|
||||
# Find all keys using litellm.get_secret()
|
||||
get_secret_matches = get_secret_pattern.findall(content)
|
||||
env_keys.update(match for match in get_secret_matches)
|
||||
|
||||
# Find all keys using litellm.get_secret_str()
|
||||
get_secret_str_matches = get_secret_str_pattern.findall(content)
|
||||
env_keys.update(match for match in get_secret_str_matches)
|
||||
|
||||
# Print the unique keys found
|
||||
print(env_keys)
|
||||
|
||||
|
||||
# Parse the documentation to extract documented keys
|
||||
repo_base = "./"
|
||||
print(os.listdir(repo_base))
|
||||
docs_path = (
|
||||
"./docs/my-website/docs/proxy/config_settings.md" # Path to the documentation
|
||||
)
|
||||
documented_keys = set()
|
||||
try:
|
||||
with open(docs_path, "r", encoding="utf-8") as docs_file:
|
||||
content = docs_file.read()
|
||||
|
||||
print(f"content: {content}")
|
||||
|
||||
# Find the section titled "general_settings - Reference"
|
||||
general_settings_section = re.search(
|
||||
r"### environment variables - Reference(.*?)(?=\n###|\Z)",
|
||||
content,
|
||||
re.DOTALL | re.MULTILINE,
|
||||
)
|
||||
print(f"general_settings_section: {general_settings_section}")
|
||||
if general_settings_section:
|
||||
# Extract the table rows, which contain the documented keys
|
||||
table_content = general_settings_section.group(1)
|
||||
doc_key_pattern = re.compile(
|
||||
r"\|\s*([^\|]+?)\s*\|"
|
||||
) # Capture the key from each row of the table
|
||||
documented_keys.update(doc_key_pattern.findall(table_content))
|
||||
except Exception as e:
|
||||
raise Exception(
|
||||
f"Error reading documentation: {e}, \n repo base - {os.listdir(repo_base)}"
|
||||
)
|
||||
|
||||
|
||||
print(f"documented_keys: {documented_keys}")
|
||||
# Compare and find undocumented keys
|
||||
undocumented_keys = env_keys - documented_keys
|
||||
|
||||
# Print results
|
||||
print("Keys expected in 'environment settings' (found in code):")
|
||||
for key in sorted(env_keys):
|
||||
print(key)
|
||||
|
||||
if undocumented_keys:
|
||||
raise Exception(
|
||||
f"\nKeys not documented in 'environment settings - Reference': {undocumented_keys}"
|
||||
)
|
||||
else:
|
||||
print(
|
||||
"\nAll keys are documented in 'environment settings - Reference'. - {}".format(
|
||||
env_keys
|
||||
)
|
||||
)
|
@@ -0,0 +1,81 @@
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
import io
|
||||
import re
|
||||
|
||||
# Backup the original sys.path
|
||||
original_sys_path = sys.path.copy()
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system path
|
||||
import litellm
|
||||
|
||||
public_exceptions = litellm.LITELLM_EXCEPTION_TYPES
|
||||
# Regular expression to extract the error name
|
||||
error_name_pattern = re.compile(r"\.exceptions\.([A-Za-z]+Error)")
|
||||
|
||||
# Extract error names from each item
|
||||
error_names = {
|
||||
error_name_pattern.search(str(item)).group(1)
|
||||
for item in public_exceptions
|
||||
if error_name_pattern.search(str(item))
|
||||
}
|
||||
|
||||
|
||||
# sys.path = original_sys_path
|
||||
|
||||
|
||||
# Parse the documentation to extract documented keys
|
||||
# repo_base = "./"
|
||||
repo_base = "../../"
|
||||
print(os.listdir(repo_base))
|
||||
docs_path = f"{repo_base}/docs/my-website/docs/exception_mapping.md" # Path to the documentation
|
||||
documented_keys = set()
|
||||
try:
|
||||
with open(docs_path, "r", encoding="utf-8") as docs_file:
|
||||
content = docs_file.read()
|
||||
|
||||
exceptions_section = re.search(
|
||||
r"## LiteLLM Exceptions(.*?)\n##", content, re.DOTALL
|
||||
)
|
||||
if exceptions_section:
|
||||
# Step 2: Extract the table content
|
||||
table_content = exceptions_section.group(1)
|
||||
|
||||
# Step 3: Create a pattern to capture the Error Types from each row
|
||||
error_type_pattern = re.compile(r"\|\s*[^|]+\s*\|\s*([^\|]+?)\s*\|")
|
||||
|
||||
# Extract the error types
|
||||
exceptions = error_type_pattern.findall(table_content)
|
||||
print(f"exceptions: {exceptions}")
|
||||
|
||||
# Remove extra spaces if any
|
||||
exceptions = [exception.strip() for exception in exceptions]
|
||||
|
||||
print(exceptions)
|
||||
documented_keys.update(exceptions)
|
||||
|
||||
except Exception as e:
|
||||
raise Exception(
|
||||
f"Error reading documentation: {e}, \n repo base - {os.listdir(repo_base)}"
|
||||
)
|
||||
|
||||
print(documented_keys)
|
||||
print(public_exceptions)
|
||||
print(error_names)
|
||||
|
||||
# Compare and find undocumented keys
|
||||
undocumented_keys = error_names - documented_keys
|
||||
|
||||
if undocumented_keys:
|
||||
raise Exception(
|
||||
f"\nKeys not documented in 'LiteLLM Exceptions': {undocumented_keys}"
|
||||
)
|
||||
else:
|
||||
print("\nAll keys are documented in 'LiteLLM Exceptions'. - {}".format(error_names))
|
@@ -0,0 +1,78 @@
|
||||
import os
|
||||
import re
|
||||
|
||||
# Define the base directory for the litellm repository and documentation path
|
||||
repo_base = "./litellm" # Change this to your actual path
|
||||
|
||||
|
||||
# Regular expressions to capture the keys used in general_settings.get() and general_settings[]
|
||||
get_pattern = re.compile(
|
||||
r'general_settings\.get\(\s*[\'"]([^\'"]+)[\'"](,?\s*[^)]*)?\)'
|
||||
)
|
||||
bracket_pattern = re.compile(r'general_settings\[\s*[\'"]([^\'"]+)[\'"]\s*\]')
|
||||
|
||||
# Set to store unique keys from the code
|
||||
general_settings_keys = set()
|
||||
|
||||
# Walk through all files in the litellm repo to find references of general_settings
|
||||
for root, dirs, files in os.walk(repo_base):
|
||||
for file in files:
|
||||
if file.endswith(".py"): # Only process Python files
|
||||
file_path = os.path.join(root, file)
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
# Find all keys using general_settings.get()
|
||||
get_matches = get_pattern.findall(content)
|
||||
general_settings_keys.update(
|
||||
match[0] for match in get_matches
|
||||
) # Extract only the key part
|
||||
|
||||
# Find all keys using general_settings[]
|
||||
bracket_matches = bracket_pattern.findall(content)
|
||||
general_settings_keys.update(bracket_matches)
|
||||
|
||||
# Parse the documentation to extract documented keys
|
||||
repo_base = "./"
|
||||
print(os.listdir(repo_base))
|
||||
docs_path = (
|
||||
"./docs/my-website/docs/proxy/config_settings.md" # Path to the documentation
|
||||
)
|
||||
documented_keys = set()
|
||||
try:
|
||||
with open(docs_path, "r", encoding="utf-8") as docs_file:
|
||||
content = docs_file.read()
|
||||
|
||||
# Find the section titled "general_settings - Reference"
|
||||
general_settings_section = re.search(
|
||||
r"### general_settings - Reference(.*?)###", content, re.DOTALL
|
||||
)
|
||||
if general_settings_section:
|
||||
# Extract the table rows, which contain the documented keys
|
||||
table_content = general_settings_section.group(1)
|
||||
doc_key_pattern = re.compile(
|
||||
r"\|\s*([^\|]+?)\s*\|"
|
||||
) # Capture the key from each row of the table
|
||||
documented_keys.update(doc_key_pattern.findall(table_content))
|
||||
except Exception as e:
|
||||
raise Exception(
|
||||
f"Error reading documentation: {e}, \n repo base - {os.listdir(repo_base)}"
|
||||
)
|
||||
|
||||
# Compare and find undocumented keys
|
||||
undocumented_keys = general_settings_keys - documented_keys
|
||||
|
||||
# Print results
|
||||
print("Keys expected in 'general_settings' (found in code):")
|
||||
for key in sorted(general_settings_keys):
|
||||
print(key)
|
||||
|
||||
if undocumented_keys:
|
||||
raise Exception(
|
||||
f"\nKeys not documented in 'general_settings - Reference': {undocumented_keys}"
|
||||
)
|
||||
else:
|
||||
print(
|
||||
"\nAll keys are documented in 'general_settings - Reference'. - {}".format(
|
||||
general_settings_keys
|
||||
)
|
||||
)
|
@@ -0,0 +1,151 @@
|
||||
import ast
|
||||
from typing import List, Set, Dict, Optional
|
||||
import sys
|
||||
|
||||
|
||||
class ConfigChecker(ast.NodeVisitor):
|
||||
def __init__(self):
|
||||
self.errors: List[str] = []
|
||||
self.current_provider_block: Optional[str] = None
|
||||
self.param_assignments: Dict[str, Set[str]] = {}
|
||||
self.map_openai_calls: Set[str] = set()
|
||||
self.class_inheritance: Dict[str, List[str]] = {}
|
||||
|
||||
def get_full_name(self, node):
|
||||
"""Recursively extract the full name from a node."""
|
||||
if isinstance(node, ast.Name):
|
||||
return node.id
|
||||
elif isinstance(node, ast.Attribute):
|
||||
base = self.get_full_name(node.value)
|
||||
if base:
|
||||
return f"{base}.{node.attr}"
|
||||
return None
|
||||
|
||||
def visit_ClassDef(self, node: ast.ClassDef):
|
||||
# Record class inheritance
|
||||
bases = [base.id for base in node.bases if isinstance(base, ast.Name)]
|
||||
print(f"Found class {node.name} with bases {bases}")
|
||||
self.class_inheritance[node.name] = bases
|
||||
self.generic_visit(node)
|
||||
|
||||
def visit_Call(self, node: ast.Call):
|
||||
# Check for map_openai_params calls
|
||||
if (
|
||||
isinstance(node.func, ast.Attribute)
|
||||
and node.func.attr == "map_openai_params"
|
||||
):
|
||||
if isinstance(node.func.value, ast.Name):
|
||||
config_name = node.func.value.id
|
||||
self.map_openai_calls.add(config_name)
|
||||
self.generic_visit(node)
|
||||
|
||||
def visit_If(self, node: ast.If):
|
||||
# Detect custom_llm_provider blocks
|
||||
provider = self._extract_provider_from_if(node)
|
||||
if provider:
|
||||
old_provider = self.current_provider_block
|
||||
self.current_provider_block = provider
|
||||
self.generic_visit(node)
|
||||
self.current_provider_block = old_provider
|
||||
else:
|
||||
self.generic_visit(node)
|
||||
|
||||
def visit_Assign(self, node: ast.Assign):
|
||||
# Track assignments to optional_params
|
||||
if self.current_provider_block and len(node.targets) == 1:
|
||||
target = node.targets[0]
|
||||
if isinstance(target, ast.Subscript) and isinstance(target.value, ast.Name):
|
||||
if target.value.id == "optional_params":
|
||||
if isinstance(target.slice, ast.Constant):
|
||||
key = target.slice.value
|
||||
if self.current_provider_block not in self.param_assignments:
|
||||
self.param_assignments[self.current_provider_block] = set()
|
||||
self.param_assignments[self.current_provider_block].add(key)
|
||||
self.generic_visit(node)
|
||||
|
||||
def _extract_provider_from_if(self, node: ast.If) -> Optional[str]:
|
||||
"""Extract the provider name from an if condition checking custom_llm_provider"""
|
||||
if isinstance(node.test, ast.Compare):
|
||||
if len(node.test.ops) == 1 and isinstance(node.test.ops[0], ast.Eq):
|
||||
if (
|
||||
isinstance(node.test.left, ast.Name)
|
||||
and node.test.left.id == "custom_llm_provider"
|
||||
):
|
||||
if isinstance(node.test.comparators[0], ast.Constant):
|
||||
return node.test.comparators[0].value
|
||||
return None
|
||||
|
||||
def check_patterns(self) -> List[str]:
|
||||
# Check if all configs using map_openai_params inherit from BaseConfig
|
||||
for config_name in self.map_openai_calls:
|
||||
print(f"Checking config: {config_name}")
|
||||
if (
|
||||
config_name not in self.class_inheritance
|
||||
or "BaseConfig" not in self.class_inheritance[config_name]
|
||||
):
|
||||
# Retrieve the associated class name, if any
|
||||
class_name = next(
|
||||
(
|
||||
cls
|
||||
for cls, bases in self.class_inheritance.items()
|
||||
if config_name in bases
|
||||
),
|
||||
"Unknown Class",
|
||||
)
|
||||
self.errors.append(
|
||||
f"Error: {config_name} calls map_openai_params but doesn't inherit from BaseConfig. "
|
||||
f"It is used in the class: {class_name}"
|
||||
)
|
||||
|
||||
# Check for parameter assignments in provider blocks
|
||||
for provider, params in self.param_assignments.items():
|
||||
# You can customize which parameters should raise warnings for each provider
|
||||
for param in params:
|
||||
if param not in self._get_allowed_params(provider):
|
||||
self.errors.append(
|
||||
f"Warning: Parameter '{param}' is directly assigned in {provider} block. "
|
||||
f"Consider using a config class instead."
|
||||
)
|
||||
|
||||
return self.errors
|
||||
|
||||
def _get_allowed_params(self, provider: str) -> Set[str]:
|
||||
"""Define allowed direct parameter assignments for each provider"""
|
||||
# You can customize this based on your requirements
|
||||
common_allowed = {"stream", "api_key", "api_base"}
|
||||
provider_specific = {
|
||||
"anthropic": {"api_version"},
|
||||
"openai": {"organization"},
|
||||
# Add more providers and their allowed params here
|
||||
}
|
||||
return common_allowed.union(provider_specific.get(provider, set()))
|
||||
|
||||
|
||||
def check_file(file_path: str) -> List[str]:
|
||||
with open(file_path, "r") as file:
|
||||
tree = ast.parse(file.read())
|
||||
|
||||
checker = ConfigChecker()
|
||||
for node in tree.body:
|
||||
if isinstance(node, ast.FunctionDef) and node.name == "get_optional_params":
|
||||
checker.visit(node)
|
||||
break # No need to visit other functions
|
||||
return checker.check_patterns()
|
||||
|
||||
|
||||
def main():
|
||||
file_path = "../../litellm/utils.py"
|
||||
errors = check_file(file_path)
|
||||
|
||||
if errors:
|
||||
print("\nFound the following issues:")
|
||||
for error in errors:
|
||||
print(f"- {error}")
|
||||
sys.exit(1)
|
||||
else:
|
||||
print("No issues found!")
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@@ -0,0 +1,183 @@
|
||||
"""
|
||||
Prevent usage of 'requests' library in the codebase.
|
||||
"""
|
||||
|
||||
import os
|
||||
import ast
|
||||
import sys
|
||||
from typing import List, Tuple
|
||||
|
||||
|
||||
def find_requests_usage(directory: str) -> List[Tuple[str, int, str]]:
|
||||
"""
|
||||
Recursively search for Python files in the given directory
|
||||
and find usages of the 'requests' library.
|
||||
|
||||
Args:
|
||||
directory (str): The root directory to search for Python files
|
||||
|
||||
Returns:
|
||||
List of tuples containing (file_path, line_number, usage_type)
|
||||
"""
|
||||
requests_usages = []
|
||||
|
||||
def is_likely_requests_usage(node):
|
||||
"""
|
||||
More precise check to avoid false positives
|
||||
"""
|
||||
try:
|
||||
# Convert node to string representation
|
||||
node_str = ast.unparse(node)
|
||||
|
||||
# Specific checks to ensure it's the requests library
|
||||
requests_identifiers = [
|
||||
# HTTP methods
|
||||
"requests.get",
|
||||
"requests.post",
|
||||
"requests.put",
|
||||
"requests.delete",
|
||||
"requests.head",
|
||||
"requests.patch",
|
||||
"requests.options",
|
||||
"requests.request",
|
||||
"requests.session",
|
||||
# Types and exceptions
|
||||
"requests.Response",
|
||||
"requests.Request",
|
||||
"requests.Session",
|
||||
"requests.ConnectionError",
|
||||
"requests.HTTPError",
|
||||
"requests.Timeout",
|
||||
"requests.TooManyRedirects",
|
||||
"requests.RequestException",
|
||||
# Additional modules and attributes
|
||||
"requests.api",
|
||||
"requests.exceptions",
|
||||
"requests.models",
|
||||
"requests.auth",
|
||||
"requests.cookies",
|
||||
"requests.structures",
|
||||
]
|
||||
|
||||
# Check for specific requests library identifiers
|
||||
return any(identifier in node_str for identifier in requests_identifiers)
|
||||
except:
|
||||
return False
|
||||
|
||||
def scan_file(file_path: str):
|
||||
"""
|
||||
Scan a single Python file for requests library usage
|
||||
"""
|
||||
try:
|
||||
# Use utf-8-sig to handle files with BOM, ignore errors
|
||||
with open(file_path, "r", encoding="utf-8-sig", errors="ignore") as file:
|
||||
tree = ast.parse(file.read())
|
||||
|
||||
for node in ast.walk(tree):
|
||||
# Check import statements
|
||||
if isinstance(node, ast.Import):
|
||||
for alias in node.names:
|
||||
if alias.name == "requests":
|
||||
requests_usages.append(
|
||||
(file_path, node.lineno, f"Import: {alias.name}")
|
||||
)
|
||||
|
||||
# Check import from statements
|
||||
elif isinstance(node, ast.ImportFrom):
|
||||
if node.module == "requests":
|
||||
requests_usages.append(
|
||||
(file_path, node.lineno, f"Import from: {node.module}")
|
||||
)
|
||||
|
||||
# Check method calls
|
||||
elif isinstance(node, ast.Call):
|
||||
# More precise check for requests usage
|
||||
try:
|
||||
if is_likely_requests_usage(node.func):
|
||||
requests_usages.append(
|
||||
(
|
||||
file_path,
|
||||
node.lineno,
|
||||
f"Method Call: {ast.unparse(node.func)}",
|
||||
)
|
||||
)
|
||||
except:
|
||||
pass
|
||||
|
||||
# Check attribute access
|
||||
elif isinstance(node, ast.Attribute):
|
||||
try:
|
||||
# More precise check
|
||||
if is_likely_requests_usage(node):
|
||||
requests_usages.append(
|
||||
(
|
||||
file_path,
|
||||
node.lineno,
|
||||
f"Attribute Access: {ast.unparse(node)}",
|
||||
)
|
||||
)
|
||||
except:
|
||||
pass
|
||||
|
||||
except SyntaxError as e:
|
||||
print(f"Syntax error in {file_path}: {e}", file=sys.stderr)
|
||||
except Exception as e:
|
||||
print(f"Error processing {file_path}: {e}", file=sys.stderr)
|
||||
|
||||
# Recursively walk through directory
|
||||
for root, dirs, files in os.walk(directory):
|
||||
# Remove virtual environment and cache directories from search
|
||||
dirs[:] = [
|
||||
d
|
||||
for d in dirs
|
||||
if not any(
|
||||
venv in d
|
||||
for venv in [
|
||||
"venv",
|
||||
"env",
|
||||
"myenv",
|
||||
".venv",
|
||||
"__pycache__",
|
||||
".pytest_cache",
|
||||
]
|
||||
)
|
||||
]
|
||||
|
||||
for file in files:
|
||||
if file.endswith(".py"):
|
||||
full_path = os.path.join(root, file)
|
||||
# Skip files in virtual environment or cache directories
|
||||
if not any(
|
||||
venv in full_path
|
||||
for venv in [
|
||||
"venv",
|
||||
"env",
|
||||
"myenv",
|
||||
".venv",
|
||||
"__pycache__",
|
||||
".pytest_cache",
|
||||
]
|
||||
):
|
||||
scan_file(full_path)
|
||||
|
||||
return requests_usages
|
||||
|
||||
|
||||
def main():
|
||||
# Get directory from command line argument or use current directory
|
||||
directory = "../../litellm"
|
||||
|
||||
# Find requests library usages
|
||||
results = find_requests_usage(directory)
|
||||
|
||||
# Print results
|
||||
if results:
|
||||
print("Requests Library Usages Found:")
|
||||
for file_path, line_num, usage_type in results:
|
||||
print(f"{file_path}:{line_num} - {usage_type}")
|
||||
else:
|
||||
print("No requests library usages found.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@@ -0,0 +1,87 @@
|
||||
import os
|
||||
import re
|
||||
import inspect
|
||||
from typing import Type
|
||||
import sys
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system path
|
||||
import litellm
|
||||
|
||||
|
||||
def get_init_params(cls: Type) -> list[str]:
|
||||
"""
|
||||
Retrieve all parameters supported by the `__init__` method of a given class.
|
||||
|
||||
Args:
|
||||
cls: The class to inspect.
|
||||
|
||||
Returns:
|
||||
A list of parameter names.
|
||||
"""
|
||||
if not hasattr(cls, "__init__"):
|
||||
raise ValueError(
|
||||
f"The provided class {cls.__name__} does not have an __init__ method."
|
||||
)
|
||||
|
||||
init_method = cls.__init__
|
||||
argspec = inspect.getfullargspec(init_method)
|
||||
|
||||
# The first argument is usually 'self', so we exclude it
|
||||
return argspec.args[1:] # Exclude 'self'
|
||||
|
||||
|
||||
router_init_params = set(get_init_params(litellm.router.Router))
|
||||
print(router_init_params)
|
||||
router_init_params.remove("model_list")
|
||||
|
||||
# Parse the documentation to extract documented keys
|
||||
repo_base = "./"
|
||||
print(os.listdir(repo_base))
|
||||
docs_path = (
|
||||
"./docs/my-website/docs/proxy/config_settings.md" # Path to the documentation
|
||||
)
|
||||
# docs_path = (
|
||||
# "../../docs/my-website/docs/proxy/config_settings.md" # Path to the documentation
|
||||
# )
|
||||
documented_keys = set()
|
||||
try:
|
||||
with open(docs_path, "r", encoding="utf-8") as docs_file:
|
||||
content = docs_file.read()
|
||||
|
||||
# Find the section titled "general_settings - Reference"
|
||||
general_settings_section = re.search(
|
||||
r"### router_settings - Reference(.*?)###", content, re.DOTALL
|
||||
)
|
||||
if general_settings_section:
|
||||
# Extract the table rows, which contain the documented keys
|
||||
table_content = general_settings_section.group(1)
|
||||
doc_key_pattern = re.compile(
|
||||
r"\|\s*([^\|]+?)\s*\|"
|
||||
) # Capture the key from each row of the table
|
||||
documented_keys.update(doc_key_pattern.findall(table_content))
|
||||
except Exception as e:
|
||||
raise Exception(
|
||||
f"Error reading documentation: {e}, \n repo base - {os.listdir(repo_base)}"
|
||||
)
|
||||
|
||||
|
||||
# Compare and find undocumented keys
|
||||
undocumented_keys = router_init_params - documented_keys
|
||||
|
||||
# Print results
|
||||
print("Keys expected in 'router settings' (found in code):")
|
||||
for key in sorted(router_init_params):
|
||||
print(key)
|
||||
|
||||
if undocumented_keys:
|
||||
raise Exception(
|
||||
f"\nKeys not documented in 'router settings - Reference': {undocumented_keys}"
|
||||
)
|
||||
else:
|
||||
print(
|
||||
"\nAll keys are documented in 'router settings - Reference'. - {}".format(
|
||||
router_init_params
|
||||
)
|
||||
)
|
@@ -0,0 +1,79 @@
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
|
||||
from typing import get_type_hints
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system path
|
||||
|
||||
from litellm.types.utils import StandardLoggingPayload
|
||||
|
||||
|
||||
def get_all_fields(type_dict, prefix=""):
|
||||
"""Recursively get all fields from TypedDict and its nested types"""
|
||||
fields = set()
|
||||
|
||||
# Get type hints for the TypedDict
|
||||
hints = get_type_hints(type_dict)
|
||||
|
||||
for field_name, field_type in hints.items():
|
||||
full_field_name = f"{prefix}{field_name}" if prefix else field_name
|
||||
fields.add(full_field_name)
|
||||
|
||||
# Check if the field type is another TypedDict we should process
|
||||
if hasattr(field_type, "__annotations__"):
|
||||
nested_fields = get_all_fields(field_type)
|
||||
fields.update(nested_fields)
|
||||
return fields
|
||||
|
||||
|
||||
def test_standard_logging_payload_documentation():
|
||||
# Get all fields from StandardLoggingPayload and its nested types
|
||||
all_fields = get_all_fields(StandardLoggingPayload)
|
||||
|
||||
print("All fields in StandardLoggingPayload: ")
|
||||
for _field in all_fields:
|
||||
print(_field)
|
||||
|
||||
# Read the documentation
|
||||
docs_path = "../../docs/my-website/docs/proxy/logging_spec.md"
|
||||
|
||||
try:
|
||||
with open(docs_path, "r", encoding="utf-8") as docs_file:
|
||||
content = docs_file.read()
|
||||
|
||||
# Extract documented fields from the table
|
||||
doc_field_pattern = re.compile(r"\|\s*`([^`]+?)`\s*\|")
|
||||
documented_fields = set(doc_field_pattern.findall(content))
|
||||
|
||||
# Clean up documented fields (remove whitespace)
|
||||
documented_fields = {field.strip() for field in documented_fields}
|
||||
|
||||
# Clean up documented fields (remove whitespace)
|
||||
documented_fields = {field.strip() for field in documented_fields}
|
||||
print("\n\nDocumented fields: ")
|
||||
for _field in documented_fields:
|
||||
print(_field)
|
||||
|
||||
# Compare and find undocumented fields
|
||||
undocumented_fields = all_fields - documented_fields
|
||||
|
||||
print("\n\nUndocumented fields: ")
|
||||
for _field in undocumented_fields:
|
||||
print(_field)
|
||||
|
||||
if undocumented_fields:
|
||||
raise Exception(
|
||||
f"\nFields not documented in 'StandardLoggingPayload': {undocumented_fields}"
|
||||
)
|
||||
|
||||
print(
|
||||
f"All {len(all_fields)} fields are documented in 'StandardLoggingPayload'"
|
||||
)
|
||||
|
||||
except FileNotFoundError:
|
||||
raise Exception(
|
||||
f"Documentation file not found at {docs_path}. Please ensure the documentation exists."
|
||||
)
|
Reference in New Issue
Block a user