in torchrec/linter/module_linter.py [0:0]
def check_class_definition(python_path: str, node: ast.ClassDef) -> None:
    """
    Runs a set of sanity checks against a class definition and its docstring.

    Only classes that look like TorchRec modules are checked, and files whose
    path contains the substring "tests" are skipped entirely. A class is
    treated as a TorchRec module when one of its bases is the plain name
    ``LazyModuleExtensionMixin`` or an attribute access ending in ``.Module``
    (e.g. ``torch.nn.Module``). Problems are reported via
    ``print_error_message`` rather than raised.

    Checks performed:
        1. The class has a docstring (no further checks if it does not).
        2. The docstring has an "Example:" section and at least one ">>> ".
        3. The docstring mentions "Constructor Args:", "Call Args:" and
           "Returns:".
        4. ``__init__`` takes at most ``MAX_NUM_ARGS_IN_MODULE_CTOR``
           arguments, not counting ``self``.
        5. Every argument of ``__init__`` and ``forward`` appears somewhere in
           the docstring (required args named ``self`` or ``net`` are skipped).

    Args:
        python_path: Path to the file that is getting checked.
        node: AST node with the ClassDef that needs to be checked.

    Returns:
        None

    Raises:
        AssertionError: if ``node`` is not an ``ast.ClassDef``.
    """
    assert isinstance(
        node, ast.ClassDef
    ), "Received invalid node type. Expected ClassDef, got: {}".format(type(node))

    # Test files are exempt from the module checks.
    if "tests" in python_path:
        return

    # We assume that TorchRec module has one of the following inheritance patterns:
    # 1. `class SomeTorchRecModule(LazyModuleExtensionMixin, torch.nn.Module)`
    # 2. `class SomeTorchRecModule(torch.nn.Module)`
    # For now only plain names and attribute accesses are recognized as bases;
    # anything else (e.g. subscripted generics) is ignored.
    is_torchrec_module = any(
        (isinstance(base, ast.Name) and base.id == "LazyModuleExtensionMixin")
        or (isinstance(base, ast.Attribute) and base.attr == "Module")
        for base in node.bases
    )
    if not is_torchrec_module:
        return

    docstring: Optional[str] = ast.get_docstring(node)
    if docstring is None:
        print_error_message(
            python_path,
            node,
            "No docstring found in a TorchRec module",
            "TorchRec modules are required to have a docstring describing how "
            "to use them. Given Module don't have a docstring, please fix this.",
        )
        return
    # Non-Optional alias so the nested helper below needs no None-narrowing.
    doc: str = docstring

    # Check presence of the example:
    if "Example:" not in doc or ">>> " not in doc:
        print_error_message(
            python_path,
            node,
            "No runnable example in a TorchRec module",
            "TorchRec modules are required to have runnable examples in "
            '"Example:" section, that start from ">>> ". Please fix the docstring',
        )

    # Check correctness of the Args for a class definition:
    required_keywords = ["Constructor Args:", "Call Args:", "Returns:"]
    missing_keywords = [kw for kw in required_keywords if kw not in doc]
    if missing_keywords:
        print_error_message(
            python_path,
            node,
            "Missing required keywords from TorchRec module",
            "TorchRec modules are required to description of their args and "
            'results in "Constructor Args:", "Call Args:", "Returns:". '
            "Missing keywords: {}.".format(missing_keywords),
        )

    # Check actual args from the functions. Maps method name to the
    # (required args, optional args) pair produced by get_function_args.
    # pyre-ignore[33]: Explicit annotation for `functions` cannot contain `Any`.
    functions: Dict[str, Tuple[List[Any], List[Any]]] = {
        sub_node.name: get_function_args(sub_node)
        for sub_node in node.body
        if isinstance(sub_node, ast.FunctionDef)
    }

    def check_function(function_name: str) -> None:
        """Validate arg count (``__init__`` only) and arg documentation."""
        if function_name not in functions:
            return
        function_args = functions[function_name]
        required_args = function_args[0]
        optional_args = function_args[1]

        if function_name == "__init__":
            # NOTE: -1 to not count the `self` argument.
            num_args = sum(len(args) for args in function_args) - 1
            if num_args > MAX_NUM_ARGS_IN_MODULE_CTOR:
                print_error_message(
                    python_path,
                    node,
                    "TorchRec module has too many constructor arguments",
                    "TorchRec module can have at most {} constructor arguments, but this module has {}.".format(
                        MAX_NUM_ARGS_IN_MODULE_CTOR,
                        # Report the same total that was compared to the limit
                        # (previously only the optional-arg count was shown).
                        num_args,
                    ),
                )

        # Ignore checks for required self and net args.
        missing_required_args = [
            arg
            for arg in required_args
            if arg not in ("self", "net") and arg not in doc
        ]
        missing_optional_args = [arg for arg in optional_args if arg not in doc]
        if missing_required_args or missing_optional_args:
            print_error_message(
                python_path,
                node,
                "Missing docstring descriptions for {} function arguments.".format(
                    function_name
                ),
                (
                    "Missing descriptions for {} function arguments. "
                    "Missing required args: {}, missing optional args: {}"
                ).format(
                    function_name, missing_required_args, missing_optional_args
                ),
            )

    check_function("__init__")
    check_function("forward")