in src/smolagents/tool_validation.py [0:0]
def validate_tool_attributes(cls, check_imports: bool = True) -> None:
    """
    Validate that a Tool class follows the patterns required to rebuild it from source.

    Checks performed:
    0. Every parameter of `__init__` (positional-only, regular and keyword-only, except `self`)
       must have a default, and that default must be a literal node.
       Args chosen at init are not traceable, so we cannot rebuild the source code for them;
       any important arg should therefore be defined as a class attribute.
    1. About the class:
       - Class attributes may only be assigned literal values (constants, dicts, lists, sets).
         Anything else (calls, names, comprehensions, ...) is reported as a complex attribute.
       - The `name` class attribute must be a string constant accepted by `is_valid_name`.
    2. About all class methods (delegated to `MethodChecker`, one instance per method):
       - Imports must be from packages, not local files.
       - All methods must be self-contained.

    Args:
        cls: The Tool class to validate; its source is obtained with `get_source`.
        check_imports: Forwarded to `MethodChecker` (presumably toggles the import check).

    Raises:
        ValueError: If the parsed source does not start with a class definition, or if any
            check fails. All failures are collected and reported together in one error.
    """
    # Literal AST node types. `ast.Str`/`ast.Num` are deprecated aliases of `ast.Constant`
    # (removed in Python 3.14), so `ast.Constant` alone covers them.
    literal_nodes = (ast.Constant, ast.Dict, ast.List, ast.Set)
    # `ast.walk` also yields the `ctx` child (`ast.Load`) of list nodes. It carries no value
    # and must not make an otherwise-literal list count as complex.
    literal_walk_nodes = literal_nodes + (ast.expr_context,)

    class ClassLevelChecker(ast.NodeVisitor):
        """Collect class-attribute and `__init__` parameter issues, skipping method bodies."""

        def __init__(self):
            self.imported_names = set()
            self.complex_attributes = set()
            self.class_attributes = set()
            self.non_defaults = set()
            self.non_literal_defaults = set()
            self.in_method = False
            self.invalid_attributes = []

        def visit_FunctionDef(self, node):
            if node.name == "__init__":
                self._check_init_function_parameters(node)
            # Assignments inside a method body are locals, not class attributes.
            old_context = self.in_method
            self.in_method = True
            self.generic_visit(node)
            self.in_method = old_context

        # Async methods have bodies too; without this, their local assignments would be
        # mistaken for class attributes.
        visit_AsyncFunctionDef = visit_FunctionDef

        def visit_Assign(self, node):
            if self.in_method:
                return
            target_names = [target.id for target in node.targets if isinstance(target, ast.Name)]
            # Track class attributes
            self.class_attributes.update(target_names)
            # Flag the assignment if its value contains anything beyond plain literals
            if not all(isinstance(val, literal_walk_nodes) for val in ast.walk(node.value)):
                self.complex_attributes.update(target_names)
            # Check specific class attributes
            if getattr(node.targets[0], "id", "") == "name":
                if not isinstance(node.value, ast.Constant):
                    self.invalid_attributes.append(f"Class attribute 'name' must be a constant, found '{node.value}'")
                elif not isinstance(node.value.value, str):
                    self.invalid_attributes.append(
                        f"Class attribute 'name' must be a string, found '{node.value.value}'"
                    )
                elif not is_valid_name(node.value.value):
                    self.invalid_attributes.append(
                        f"Class attribute 'name' must be a valid Python identifier and not a reserved keyword, found '{node.value.value}'"
                    )

        def _check_init_function_parameters(self, node):
            """Record `__init__` parameters lacking a default or having a non-literal one."""
            args = node.args
            # `defaults` belongs to the LAST len(defaults) positional parameters, counting
            # positional-only and regular positional parameters together.
            positional = args.posonlyargs + args.args
            missing = len(positional) - len(args.defaults)
            positional_pairs = zip(positional, [None] * missing + list(args.defaults))
            # `kw_defaults` is aligned one-to-one with `kwonlyargs`, with None where no default exists.
            keyword_pairs = zip(args.kwonlyargs, args.kw_defaults)
            for arg, default in [*positional_pairs, *keyword_pairs]:
                if default is None:
                    if arg.arg != "self":
                        self.non_defaults.add(arg.arg)
                elif not isinstance(default, literal_nodes):
                    self.non_literal_defaults.add(arg.arg)

    class_level_checker = ClassLevelChecker()
    source = get_source(cls)
    tree = ast.parse(source)
    # Guard against an empty module so the caller gets the intended ValueError, not IndexError.
    class_node = tree.body[0] if tree.body else None
    if not isinstance(class_node, ast.ClassDef):
        raise ValueError("Source code must define a class")
    class_level_checker.visit(class_node)

    errors = []
    # Check invalid class attributes
    if class_level_checker.invalid_attributes:
        errors += class_level_checker.invalid_attributes
    # Names are sorted so error messages are deterministic across runs.
    if class_level_checker.complex_attributes:
        errors.append(
            f"Complex attributes should be defined in __init__, not as class attributes: "
            f"{', '.join(sorted(class_level_checker.complex_attributes))}"
        )
    if class_level_checker.non_defaults:
        errors.append(
            f"Parameters in __init__ must have default values, found required parameters: "
            f"{', '.join(sorted(class_level_checker.non_defaults))}"
        )
    if class_level_checker.non_literal_defaults:
        errors.append(
            f"Parameters in __init__ must have literal default values, found non-literal defaults: "
            f"{', '.join(sorted(class_level_checker.non_literal_defaults))}"
        )

    # Run checks on all methods
    for method_node in class_node.body:
        if isinstance(method_node, ast.FunctionDef):
            method_checker = MethodChecker(class_level_checker.class_attributes, check_imports=check_imports)
            method_checker.visit(method_node)
            errors += [f"- {method_node.name}: {error}" for error in method_checker.errors]

    if errors:
        raise ValueError(f"Tool validation failed for {cls.__name__}:\n" + "\n".join(errors))