Added LiteLLM to the stack
This commit is contained in:
@@ -0,0 +1,151 @@
|
||||
import ast
|
||||
from typing import List, Set, Dict, Optional
|
||||
import sys
|
||||
|
||||
|
||||
class ConfigChecker(ast.NodeVisitor):
|
||||
def __init__(self):
|
||||
self.errors: List[str] = []
|
||||
self.current_provider_block: Optional[str] = None
|
||||
self.param_assignments: Dict[str, Set[str]] = {}
|
||||
self.map_openai_calls: Set[str] = set()
|
||||
self.class_inheritance: Dict[str, List[str]] = {}
|
||||
|
||||
def get_full_name(self, node):
|
||||
"""Recursively extract the full name from a node."""
|
||||
if isinstance(node, ast.Name):
|
||||
return node.id
|
||||
elif isinstance(node, ast.Attribute):
|
||||
base = self.get_full_name(node.value)
|
||||
if base:
|
||||
return f"{base}.{node.attr}"
|
||||
return None
|
||||
|
||||
def visit_ClassDef(self, node: ast.ClassDef):
|
||||
# Record class inheritance
|
||||
bases = [base.id for base in node.bases if isinstance(base, ast.Name)]
|
||||
print(f"Found class {node.name} with bases {bases}")
|
||||
self.class_inheritance[node.name] = bases
|
||||
self.generic_visit(node)
|
||||
|
||||
def visit_Call(self, node: ast.Call):
|
||||
# Check for map_openai_params calls
|
||||
if (
|
||||
isinstance(node.func, ast.Attribute)
|
||||
and node.func.attr == "map_openai_params"
|
||||
):
|
||||
if isinstance(node.func.value, ast.Name):
|
||||
config_name = node.func.value.id
|
||||
self.map_openai_calls.add(config_name)
|
||||
self.generic_visit(node)
|
||||
|
||||
def visit_If(self, node: ast.If):
|
||||
# Detect custom_llm_provider blocks
|
||||
provider = self._extract_provider_from_if(node)
|
||||
if provider:
|
||||
old_provider = self.current_provider_block
|
||||
self.current_provider_block = provider
|
||||
self.generic_visit(node)
|
||||
self.current_provider_block = old_provider
|
||||
else:
|
||||
self.generic_visit(node)
|
||||
|
||||
def visit_Assign(self, node: ast.Assign):
|
||||
# Track assignments to optional_params
|
||||
if self.current_provider_block and len(node.targets) == 1:
|
||||
target = node.targets[0]
|
||||
if isinstance(target, ast.Subscript) and isinstance(target.value, ast.Name):
|
||||
if target.value.id == "optional_params":
|
||||
if isinstance(target.slice, ast.Constant):
|
||||
key = target.slice.value
|
||||
if self.current_provider_block not in self.param_assignments:
|
||||
self.param_assignments[self.current_provider_block] = set()
|
||||
self.param_assignments[self.current_provider_block].add(key)
|
||||
self.generic_visit(node)
|
||||
|
||||
def _extract_provider_from_if(self, node: ast.If) -> Optional[str]:
|
||||
"""Extract the provider name from an if condition checking custom_llm_provider"""
|
||||
if isinstance(node.test, ast.Compare):
|
||||
if len(node.test.ops) == 1 and isinstance(node.test.ops[0], ast.Eq):
|
||||
if (
|
||||
isinstance(node.test.left, ast.Name)
|
||||
and node.test.left.id == "custom_llm_provider"
|
||||
):
|
||||
if isinstance(node.test.comparators[0], ast.Constant):
|
||||
return node.test.comparators[0].value
|
||||
return None
|
||||
|
||||
def check_patterns(self) -> List[str]:
|
||||
# Check if all configs using map_openai_params inherit from BaseConfig
|
||||
for config_name in self.map_openai_calls:
|
||||
print(f"Checking config: {config_name}")
|
||||
if (
|
||||
config_name not in self.class_inheritance
|
||||
or "BaseConfig" not in self.class_inheritance[config_name]
|
||||
):
|
||||
# Retrieve the associated class name, if any
|
||||
class_name = next(
|
||||
(
|
||||
cls
|
||||
for cls, bases in self.class_inheritance.items()
|
||||
if config_name in bases
|
||||
),
|
||||
"Unknown Class",
|
||||
)
|
||||
self.errors.append(
|
||||
f"Error: {config_name} calls map_openai_params but doesn't inherit from BaseConfig. "
|
||||
f"It is used in the class: {class_name}"
|
||||
)
|
||||
|
||||
# Check for parameter assignments in provider blocks
|
||||
for provider, params in self.param_assignments.items():
|
||||
# You can customize which parameters should raise warnings for each provider
|
||||
for param in params:
|
||||
if param not in self._get_allowed_params(provider):
|
||||
self.errors.append(
|
||||
f"Warning: Parameter '{param}' is directly assigned in {provider} block. "
|
||||
f"Consider using a config class instead."
|
||||
)
|
||||
|
||||
return self.errors
|
||||
|
||||
def _get_allowed_params(self, provider: str) -> Set[str]:
|
||||
"""Define allowed direct parameter assignments for each provider"""
|
||||
# You can customize this based on your requirements
|
||||
common_allowed = {"stream", "api_key", "api_base"}
|
||||
provider_specific = {
|
||||
"anthropic": {"api_version"},
|
||||
"openai": {"organization"},
|
||||
# Add more providers and their allowed params here
|
||||
}
|
||||
return common_allowed.union(provider_specific.get(provider, set()))
|
||||
|
||||
|
||||
def check_file(file_path: str) -> List[str]:
|
||||
with open(file_path, "r") as file:
|
||||
tree = ast.parse(file.read())
|
||||
|
||||
checker = ConfigChecker()
|
||||
for node in tree.body:
|
||||
if isinstance(node, ast.FunctionDef) and node.name == "get_optional_params":
|
||||
checker.visit(node)
|
||||
break # No need to visit other functions
|
||||
return checker.check_patterns()
|
||||
|
||||
|
||||
def main():
|
||||
file_path = "../../litellm/utils.py"
|
||||
errors = check_file(file_path)
|
||||
|
||||
if errors:
|
||||
print("\nFound the following issues:")
|
||||
for error in errors:
|
||||
print(f"- {error}")
|
||||
sys.exit(1)
|
||||
else:
|
||||
print("No issues found!")
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
Reference in New Issue
Block a user