def propagate()

in optimum/fx/parallelization/op_registry/op_handlers.py [0:0]


    def propagate(self) -> List[int]:
        """Infer the parallel axis of a binary (possibly broadcasting) op's output.

        With a single input node the op is delegated to
        ``UnaryOpParallelAxisPropagateHandler``. With two input nodes, their
        ``meta["val"].shape`` values are checked against the broadcasting rule
        (trailing dims aligned; each aligned pair equal or containing a 1), and
        each operand's axis, as reported by ``self.extract_axis``, is re-indexed
        into the broadcast output's dimensions.

        Returns:
            ``[lhs_axis]`` / ``[rhs_axis]`` when only one operand has an axis or
            both agree; ``[]`` when the two (re-indexed) axes differ; ``[None]``
            when neither operand has one (assuming ``extract_axis`` uses ``None``
            for "no parallel axis" -- TODO confirm against the base class).

        Raises:
            AssertionError: if there are not exactly two input nodes (after the
                single-node shortcut) or the shapes are not broadcastable.
        """
        input_nodes = self.node.all_input_nodes
        # Only one distinct input node (e.g. the other operand is a constant, or
        # the same node is used for both operands): handle it like a unary op.
        if len(input_nodes) == 1:
            return UnaryOpParallelAxisPropagateHandler(self.node, self.meta_key, self.config).propagate()

        assert len(input_nodes) == 2, "binary op should have exact two nodes as inputs"
        lhs_shape, rhs_shape = input_nodes[0].meta["val"].shape, input_nodes[1].meta["val"].shape
        lhs_axis = self.extract_axis(input_nodes[0])
        rhs_axis = self.extract_axis(input_nodes[1])

        # Broadcasting aligns shapes from the trailing dimension; each aligned pair
        # must be equal or contain a 1. (A divisibility test would wrongly accept
        # e.g. (4,) vs (2,) and divide by zero on zero-sized dims.)
        for lhs_dim, rhs_dim in zip(reversed(lhs_shape), reversed(rhs_shape)):
            assert (
                lhs_dim == rhs_dim or lhs_dim == 1 or rhs_dim == 1
            ), f"shape {lhs_shape} and {rhs_shape} are not broadcastable!"

        # The lower-rank operand is implicitly left-padded with size-1 dims, so its
        # axis shifts right by the rank difference to index the output's dims.
        # NOTE(review): assumes extract_axis returns non-negative indices -- confirm.
        out_ndim = max(len(lhs_shape), len(rhs_shape))
        if lhs_axis is not None:
            lhs_axis += out_ndim - len(lhs_shape)
        if rhs_axis is not None:
            rhs_axis += out_ndim - len(rhs_shape)

        if lhs_axis is None:
            return [rhs_axis]
        if rhs_axis is None:
            return [lhs_axis]
        # The operands map to different output axes, so no single axis is
        # returned; the caller presumably treats [] as "no consistent axis".
        if lhs_axis != rhs_axis:
            return []
        return [lhs_axis]