def check_class_definition()

in torchrec/linter/module_linter.py [0:0]


def check_class_definition(python_path: str, node: ast.ClassDef) -> None:
    """
    This function will run set of sanity checks against class definitions
    and their docstrings.

    Only classes that look like TorchRec modules are checked: a class qualifies
    when one of its bases is the bare name ``LazyModuleExtensionMixin`` or an
    attribute access whose final attribute is ``Module`` (e.g.
    ``torch.nn.Module``). A bare ``Module`` base name is not recognized. Files
    whose path contains the substring ``"tests"`` are skipped entirely. Every
    problem found is reported through ``print_error_message``.

    Checks performed on a TorchRec module:
        * the class has a docstring (if it does not, the remaining checks are
          skipped);
        * the docstring has an ``"Example:"`` section and a ``">>> "`` prompt;
        * the docstring contains ``"Constructor Args:"``, ``"Call Args:"`` and
          ``"Returns:"``;
        * ``__init__`` takes at most ``MAX_NUM_ARGS_IN_MODULE_CTOR`` arguments,
          not counting ``self``;
        * every ``__init__`` / ``forward`` argument name occurs somewhere in the
          docstring (``self`` and ``net`` are exempt among the required args).

    Args:
        python_path: Path to the file that is getting checked
        node: AST node with the ClassDef that needs to be checked

    Returns:
        None

    Raises:
        AssertionError: if ``node`` is not an ``ast.ClassDef``.
    """
    assert isinstance(
        node, ast.ClassDef
    ), "Received invalid node type. Expected ClassDef, got: {}".format(type(node))

    is_test_file = "tests" in python_path

    # We assume that TorchRec module has one of the following inheritance patterns:
    # 1. `class SomeTorchRecModule(LazyModuleExtensionMixin, torch.nn.Module)`
    # 2. `class SomeTorchRecModule(torch.nn.Module)`
    # For now only names and attributes are supported as bases; any other base
    # expression (calls, subscripts, ...) is ignored.
    is_TorchRec_module = any(
        (isinstance(base, ast.Name) and base.id == "LazyModuleExtensionMixin")
        or (isinstance(base, ast.Attribute) and base.attr == "Module")
        for base in node.bases
    )

    if not is_TorchRec_module or is_test_file:
        return

    docstring: Optional[str] = ast.get_docstring(node)
    if docstring is None:
        print_error_message(
            python_path,
            node,
            "No docstring found in a TorchRec module",
            "TorchRec modules are required to have a docstring describing how "
            "to use them. Given Module don't have a docstring, please fix this.",
        )
        return

    # Check presence of the example:
    if "Example:" not in docstring or ">>> " not in docstring:
        print_error_message(
            python_path,
            node,
            "No runnable example in a TorchRec module",
            "TorchRec modules are required to have runnable examples in "
            '"Example:" section, that start from ">>> ". Please fix the docstring',
        )

    # Check that the docstring contains every required section header.
    missing_keywords = [
        keyword
        for keyword in ("Constructor Args:", "Call Args:", "Returns:")
        if keyword not in docstring
    ]
    if missing_keywords:
        print_error_message(
            python_path,
            node,
            "Missing required keywords from TorchRec module",
            "TorchRec modules are required to description of their args and "
            'results in "Constructor Args:", "Call Args:", "Returns:". '
            "Missing keywords: {}.".format(missing_keywords),
        )

    # Collect argument lists of the methods defined directly in the class body.
    # NOTE(review): get_function_args is presumably returning
    # (required_arg_names, optional_arg_names); that is how it is indexed below.
    # pyre-ignore[33]: Explicit annotation for `functions` cannot contain `Any`.
    functions: Dict[str, Tuple[List[Any], List[Any]]] = {
        sub_node.name: get_function_args(sub_node)
        for sub_node in node.body
        if isinstance(sub_node, ast.FunctionDef)
    }

    def check_function(function_name: str) -> None:
        """Report arg-count and undocumented-arg problems for one method."""
        if function_name not in functions:
            return

        # Type-checker narrowing does not carry into closures; the early return
        # above already guarantees a docstring exists at this point.
        assert docstring is not None

        function_args = functions[function_name]
        required_args, optional_args = function_args[0], function_args[1]

        if function_name == "__init__":
            # NOTE: -1 to not count the `self` argument.
            num_args = sum(len(args) for args in function_args) - 1
            if num_args > MAX_NUM_ARGS_IN_MODULE_CTOR:
                print_error_message(
                    python_path,
                    node,
                    "TorchRec module has too many constructor arguments",
                    "TorchRec module can have at most {} constructor arguments, but this module has {}.".format(
                        MAX_NUM_ARGS_IN_MODULE_CTOR,
                        # Report the same count that was compared against the
                        # limit (all arguments except `self`), not only the
                        # optional ones.
                        num_args,
                    ),
                )

        # NOTE(review): this is a plain substring test, so a short argument
        # name (e.g. `x`) counts as documented if it occurs anywhere in the
        # docstring.
        # Required `self` and `net` args are exempt from this check.
        missing_required_args = [
            arg
            for arg in required_args
            if arg not in ("self", "net") and arg not in docstring
        ]
        missing_optional_args = [arg for arg in optional_args if arg not in docstring]

        if missing_required_args or missing_optional_args:
            print_error_message(
                python_path,
                node,
                "Missing docstring descriptions for {} function arguments.".format(
                    function_name
                ),
                (
                    "Missing descriptions for {} function arguments. "
                    "Missing required args: {}, missing optional args: {}"
                ).format(
                    function_name, missing_required_args, missing_optional_args
                ),
            )

    check_function("__init__")
    check_function("forward")