in fairscale/experimental/nn/auto_shard.py [0:0]
def _split_nodes(traced_graph_module: torch.fx.GraphModule, shard_count: int = 3) -> Dict[str, int]:
    """Utility used to trace a graph and identify shard cutpoints.

    Walks the nodes of ``traced_graph_module.graph`` in order and assigns each
    node a shard id. A new shard is started once the parameters attributed to
    the current shard exceed ``total_params // shard_count``. If a node reads
    an input (other than the node visited immediately before it) that lives in
    an earlier shard -- e.g. a skip connection -- every node from that input up
    to the current node is reassigned to the input's shard, so that connection
    does not cross a shard boundary.

    Args:
        traced_graph_module: The ``torch.fx`` traced module to split.
        shard_count: Target number of shards, used only to compute the
            per-shard parameter budget. Because shards can be collapsed, the
            number of distinct shard ids produced may differ.

    Returns:
        A mapping from node name to shard id for every ``placeholder``,
        ``get_attr``, ``call_function``, ``call_method`` and ``call_module``
        node visited before the graph's ``output`` node.

    Raises:
        ValueError: If ``shard_count`` is less than 1.
    """
    if shard_count < 1:
        raise ValueError(f"shard_count must be a positive integer, got {shard_count}")

    node_name_to_shard_id: Dict[str, int] = {}
    shard_id = 0
    # Names of visited nodes in graph order; walked backwards when collapsing shards.
    nodes_so_far = []

    # Find the total number of params in the model and the number of params per
    # shard we are aiming for. Counts include submodules (``parameters()``
    # recurses by default). "." is replaced with "_", presumably so the keys
    # match torch.fx node names for call_module nodes -- TODO confirm against
    # _create_shard_to_param_count.
    param_count: Dict[str, int] = {}
    for name, module in traced_graph_module.named_modules():
        name = name.replace(".", "_")
        param_count[name] = sum(p.numel() for p in module.parameters())

    # The root module is reported by named_modules() under the empty name "".
    total_params = param_count[""]
    logging.info("Total number of params are %s", total_params)
    per_shard_param = total_params // shard_count
    logging.info("Per shard param count %s", per_shard_param)

    for node in traced_graph_module.graph.nodes:
        if node.op == "placeholder":
            node_name_to_shard_id[node.name] = shard_id
            nodes_so_far.append(node.name)
        elif node.op in ("get_attr", "call_function", "call_method", "call_module"):
            # Find the input that lives in the earliest shard, ignoring the node
            # visited immediately before this one. This is how skip connections
            # across shards are detected. ``all_input_nodes`` includes nodes
            # passed positionally, via kwargs, or nested in containers such as
            # ``torch.cat([x, skip])``; scanning only top-level ``node.args``
            # would miss the latter two.
            min_shard_id = shard_id
            min_node_name = ""
            for input_node in node.all_input_nodes:
                input_shard_id = node_name_to_shard_id.get(input_node.name)
                # A known input implies nodes_so_far is non-empty, so [-1] is safe.
                if input_shard_id is None or input_node.name == nodes_so_far[-1]:
                    continue
                if input_shard_id < min_shard_id:
                    min_shard_id = input_shard_id
                    min_node_name = input_node.name

            # If there is an input from an earlier shard, collapse every node
            # from that input up to now into the input's shard.
            if min_shard_id < shard_id:
                for node_name in reversed(nodes_so_far):
                    node_name_to_shard_id[node_name] = min_shard_id
                    if node_name == min_node_name:
                        break
                shard_id = min_shard_id
                # TODO(anj-s): Find a way to raise an error early if this can cause OOM errors.

            # Update state that is tracking node -> shard id and shard id -> param count.
            node_name_to_shard_id[node.name] = shard_id
            nodes_so_far.append(node.name)
            # TODO(anj): This could just be an update, we don't need to recreate the map.
            shard_to_param_count = _create_shard_to_param_count(param_count, node_name_to_shard_id)
            # If the current shard has gone over its parameter budget, start a new
            # shard. shard_id may be missing from the map, e.g. when no node
            # assigned to it so far has parameters.
            if shard_id in shard_to_param_count and shard_to_param_count[shard_id] > per_shard_param:
                shard_id += 1
        elif node.op == "output":
            break
    return node_name_to_shard_id