in torchrec/distributed/train_pipeline.py [0:0]
def _get_node_args_helper(arguments, num_found: int) -> Tuple[List[ArgInfo], int]:
    """
    Goes through the args/kwargs of a node and arranges them into a list of ArgInfos.
    It also counts the number of (args + kwargs) found.

    One ``ArgInfo`` is produced per entry of ``arguments``, in the same order.
    An entry that is a ``torch.fx.Node`` is followed back through any chain of
    builtin ``getattr`` calls. The attribute names are recorded in
    ``ArgInfo.input_attrs`` in access order, so ``x.a.b`` yields ``["a", "b"]``.

    An entry counts as "found" (``num_found`` is incremented) when either:
      * it is ``None``, or
      * its getattr chain ends at a ``placeholder`` node.
    Any other entry still gets an ``ArgInfo``, possibly with partially filled
    ``input_attrs``, but is not counted. This covers non-Node values and chains
    ending at some other kind of node.

    Args:
        arguments: iterable of a node's positional args or kwarg values. Any
            iterable is accepted; it is consumed exactly once.
        num_found: running count to add to.

    Returns:
        Tuple of (list of ``ArgInfo``, one per entry in order; updated count).
    """
    arg_info_list = []
    for arg in arguments:
        # Build ArgInfos while iterating so unsized iterables (e.g. generators)
        # work as well as lists/tuples/dict views.
        arg_info = ArgInfo([], None)
        arg_info_list.append(arg_info)
        if arg is None:
            num_found += 1
            continue
        # Walk from the consumed value back towards its source, peeling off one
        # ``getattr`` per step. Stop at a placeholder or at anything else.
        while isinstance(arg, torch.fx.Node):
            child_node = arg
            if child_node.op == "placeholder":
                num_found += 1
                break
            # Compare the target object itself against builtins.getattr instead of
            # reading ``target.__module__`` / ``target.__name__``. Those attributes
            # are missing on some callables and would raise AttributeError.
            if child_node.op == "call_function" and child_node.target is getattr:
                arg_info.input_attrs.insert(0, child_node.args[1])
                arg = child_node.args[0]
            else:
                break
    return arg_info_list, num_found