def _get_node_args_helper()

in torchrec/distributed/train_pipeline.py [0:0]


def _get_node_args_helper(arguments, num_found: int) -> Tuple[List[ArgInfo], int]:
    """
    Turn the arguments of an FX node into a list of ArgInfos and count them.

    One ``ArgInfo([], None)`` is created per entry of ``arguments``. The
    entries are presumably ``node.args`` or the values of ``node.kwargs``;
    confirm against callers. For each entry, any chain of builtin
    ``getattr`` calls is unwound toward its root. The attribute names are
    recorded in root-to-leaf order, so ``getattr(getattr(x, "a"), "b")``
    yields ``input_attrs == ["a", "b"]``.

    Args:
        arguments: sized sequence of node arguments.
        num_found: running count that this call adds to.

    Returns:
        A tuple ``(arg_infos, num_found)``. ``arg_infos`` has one ArgInfo per
        entry of ``arguments``. ``num_found`` is the input count plus one for
        every entry that is ``None`` or whose getattr chain bottoms out at a
        ``placeholder`` node. Constants, and chains rooted at any other kind
        of node, are not counted. Their ``input_attrs`` may still hold the
        getattr names walked before the non-placeholder root was reached.
    """
    infos = [ArgInfo([], None) for _ in range(len(arguments))]
    for value, info in zip(arguments, infos):
        if value is None:
            num_found += 1
            continue
        node = value
        while isinstance(node, torch.fx.Node):
            if node.op == "placeholder":
                # Reached a graph input: this argument is fully traceable.
                num_found += 1
                break
            is_builtin_getattr = (
                node.op == "call_function"
                and node.target.__module__ == "builtins"
                # pyre-ignore [16]
                and node.target.__name__ == "getattr"
            )
            if not is_builtin_getattr:
                # Any other op stops the walk without counting the argument.
                break
            # Prepend so the finished list reads from root to outermost access.
            info.input_attrs.insert(0, node.args[1])
            node = node.args[0]
    return infos, num_found