def propagate()

in optimum/fx/parallelization/op_registry/op_handlers.py


    def propagate(self) -> List[Optional[int]]:
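        # each candidate returned below is an axis along which the output of `self.node` may be
        # sharded (1 = sequence dimension, 2 = head dimension, as used in the comments below);
        # `None` means the output is not parallelized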
        # by default we don't parallelize inputs and constants (except parameters embedded in modules)
        if self.node.op in ["placeholder", "get_attr"]:
            return [None]
        elif self.node.op == "output":
            # we don't care whether the output is parallelized right now, because if the output is the loss,
            # it must not be parallelized as long as it comes from a sharded cross entropy.
            # TODO: append all-gather comm ops before all parallelized output nodes if instructed.
            input_arg = self.node.all_input_nodes[0]
            axis = self.extract_axis(input_arg)
            return [axis]
        elif is_linear(self.node):
            input_arg = self.node.all_input_nodes[0]
            axis = self.extract_axis(input_arg)
            if axis is None:
                # with the input not parallelized, the output can either be parallelized on the head
                # dimension, i.e., `ColumnLinear`, or left unparallelized via an all-gather at the end
                return [2, None]
            elif self.config.enable_sequence_parallel and axis == 1:
                # with the input parallelized on the sequence dimension, the output can either be
                # parallelized on the head dimension, i.e., `ColumnLinear` with sequence parallelism,
                # or left unparallelized via an all-gather at the end
                return [2, None]
            elif axis == 2:
                # with the input parallelized on the head dimension, i.e., `RowLinear`, the output can
                # be parallelized on the sequence dimension (when sequence parallelism is enabled) or
                # left unparallelized via an all-reduce at the end
                return [1, None] if self.config.enable_sequence_parallel else [None]
            else:
                return []
        elif is_embedding(self.node):
            input_arg = self.node.all_input_nodes[0]
            axis = self.extract_axis(input_arg)
            if axis is None:
                # for now we only support the embedding weight being parallelized on the `vocab` dimension
                # (or not parallelized at all); the output can be parallelized on the sequence dimension
                # or left unparallelized
                return [1, None] if self.config.enable_sequence_parallel else [None]
            else:
                return []
        elif is_cross_entropy(self.node):
            logits = self.node.all_input_nodes[0]
            axis = self.extract_axis(logits)
            if axis is None or (
                is_cross_entropy_parallel_compatible(self.node) and axis == logits.meta["val"].ndim - 1
            ):
                # for cross entropy, the parallel axis of the input logits can only be the last axis or `None`;
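                # the loss it produces is itself not parallelized, so the output axis is always `None`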
                return [None]
            else:
                return []
        elif is_activation(self.node):
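            # activation functions are unary elementwise ops, so delegate to the unary handler:
            # the output keeps whatever parallel axis its input has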
            return UnaryOpParallelAxisPropagateHandler(self.node, self.meta_key, self.config).propagate()

        # last resort: if no input is parallelized, then the output is not parallelized either.
        # this spares us from writing policies for exotic ops that don't actually need
        # parallelization in most cases
        if all(self.extract_axis(arg) is None for arg in self.node.all_input_nodes):
            return [None]

        raise NotImplementedError(f"don't know how to propagate axis for {self.node.target}")
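
For intuition, the decision table for linear layers above can be mirrored in a small standalone sketch. This is purely illustrative and not part of Optimum's API: the function name `linear_output_axes` and the axis constants are made up here, but the candidate lists it returns match what `propagate()` yields for `is_linear` nodes (with `None` meaning "not parallelized", 1 the sequence dimension and 2 the head dimension).

from typing import List, Optional

SEQ, HEAD = 1, 2  # sequence and head dimensions of a (batch, sequence, hidden) activation

def linear_output_axes(input_axis: Optional[int], sequence_parallel: bool) -> List[Optional[int]]:
    # hypothetical sketch: reproduces the candidate output axes of a linear node
    # given the parallel axis of its input
    if input_axis is None:
        # ColumnLinear, or gather the output so it stays unparallelized
        return [HEAD, None]
    if sequence_parallel and input_axis == SEQ:
        # ColumnLinear with sequence parallelism, or all-gather at the end
        return [HEAD, None]
    if input_axis == HEAD:
        # RowLinear: scatter to the sequence dimension (with sequence parallelism)
        # or all-reduce into an unparallelized output
        return [SEQ, None] if sequence_parallel else [None]
    return []  # no supported strategy for this input axis

# example: a column-parallel linear feeding a row-parallel linear
assert linear_output_axes(None, sequence_parallel=False) == [HEAD, None]
assert linear_output_axes(HEAD, sequence_parallel=False) == [None]

Choosing `HEAD` for the first linear and `None` for the second reproduces the classic column-then-row tensor-parallel pattern, where the only communication is the all-reduce after the second linear.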