def propagate()

in optimum/fx/parallelization/op_registry/op_handlers.py [0:0]


    def propagate(self) -> List[int]:
        """Propagate the parallel axis of the sliced input through a slice op.

        Positional args of ``self.node`` are read as
        ``(input, dim, start, stop, step)``; trailing args may be omitted, and
        ``None`` for start/stop/step means "use the default" (0 / full size / 1).
        This ordering presumably mirrors ``aten.slice.Tensor`` — confirm against
        the op this handler is registered for.

        Returns:
            ``[None]`` if the input carries no parallel axis; ``[axis]`` if the
            parallel axis survives the slice (slice on another dim, or a nop
            slice on the parallel dim); ``[]`` if the slice actually cuts along
            the parallel axis, which is not supported.
        """
        args = self.node.args
        arg = args[0]
        # dim defaults to 0 when omitted (same default as aten.slice.Tensor).
        slice_dim = args[1] if len(args) > 1 else 0
        axis = self.extract_axis(arg)
        if axis is None:
            return [None]
        val = arg.meta["val"]
        ndim = val.ndim
        # Normalize negative dims so they can be compared with `axis`.
        slice_dim = (slice_dim + ndim) % ndim
        if slice_dim != axis:
            return [axis]

        # Slicing along the parallel axis is only allowed when it is a nop,
        # i.e. it keeps every element: start == 0, stop >= size, step == 1.
        # Non-integer bounds (e.g. symbolic values) cannot be proven to be a
        # nop, so reject them conservatively instead of comparing them.
        bounds = args[2:5]
        if any(b is not None and not isinstance(b, int) for b in bounds):
            return []

        size = val.shape[axis]
        # Each bound is read independently: any prefix of (start, stop, step)
        # may be present, and None falls back to the default.
        start = args[2] if len(args) > 2 and args[2] is not None else 0
        stop = args[3] if len(args) > 3 and args[3] is not None else size
        step = args[4] if len(args) > 4 and args[4] is not None else 1
        if start == 0 and stop >= size and step == 1:
            return [axis]
        return []