in optimum/fx/parallelization/op_registry/op_handlers.py [0:0]
def propagate(self) -> List[Optional[int]]:
    """Infer candidate parallel axes for the output of ``self.node``.

    Based on the op kind of ``self.node`` and the parallel axes already
    assigned to its inputs (read with ``self.extract_axis``), return the
    list of axis indices along which the output may be parallelized.

    Returns:
        A list of candidate axes where ``None`` means "output not
        parallelized" is a valid choice. An empty list means the current
        input axis assignment is not supported for this op, so the search
        must backtrack. Raises ``NotImplementedError`` when no rule matches
        and at least one input is parallelized.
    """
    # by default we don't parallelize inputs and constants (except parameters embedded in modules)
    if self.node.op in ["placeholder", "get_attr"]:
        return [None]
    elif self.node.op == "output":
        # does not care about if output is being parallelized right now, because if the output is loss,
        # then it must be not parallelized as long as it comes from sharded cross entropy.
        # TODO: append all-gather comm ops before all parallelized output nodes if instructed.
        input_arg = self.node.all_input_nodes[0]
        axis = self.extract_axis(input_arg)
        return [axis]
    elif is_linear(self.node):
        input_arg = self.node.all_input_nodes[0]
        axis = self.extract_axis(input_arg)
        if axis is None:
            # with input being not parallelized, output can be parallelized on the head dimension,
            # i.e., `ColumnLinear`, or not being parallelized by all-gather at the end
            return [2, None]
        elif self.config.enable_sequence_parallel and axis == 1:
            # with input being parallelized on sequence dimension, output can be parallelized on
            # the head dimension, i.e., `ColumnLinear` with sequence parallel, or not being parallelized
            # by all-gather at the end
            return [2, None]
        elif axis == 2:
            # with input being parallelized on head dimension, output can be parallelized on the
            # sequence dimension or not parallelized by all-reduce at the end, i.e., `RowLinear`
            # when sp is not enabled
            return [1, None] if self.config.enable_sequence_parallel else [None]
        else:
            # any other input axis (e.g. sequence axis without sp enabled) is unsupported
            return []
    elif is_embedding(self.node):
        input_arg = self.node.all_input_nodes[0]
        axis = self.extract_axis(input_arg)
        if axis is None:
            # only support the embedding parameter being parallelized on `vocab` dim or not parallelized for now,
            # the output can be parallelized on sequence dim or not parallelized
            return [1, None] if self.config.enable_sequence_parallel else [None]
        else:
            # parallelized embedding *input* (token ids) is not supported
            return []
    elif is_cross_entropy(self.node):
        logits = self.node.all_input_nodes[0]
        axis = self.extract_axis(logits)
        if axis is None or (
            is_cross_entropy_parallel_compatible(self.node) and axis == logits.meta["val"].ndim - 1
        ):
            # for cross entropy, the input logits parallel axis can only be the last axis or None
            return [None]
        else:
            return []
    elif is_activation(self.node):
        # activations are elementwise, so delegate to the unary (axis-preserving) handler
        return UnaryOpParallelAxisPropagateHandler(self.node, self.meta_key, self.config).propagate()
    # last resort, if no input is being parallelized, then we make output also not parallelized,
    # this will give us relief on writing policies for strange ops which don't actually need
    # parallelization in most cases
    if all(self.extract_axis(arg) is None for arg in self.node.all_input_nodes):
        return [None]
    raise NotImplementedError(f"don't know how to propagate axis for {self.node.target}")