in optimum/fx/parallelization/op_registry/op_handlers.py [0:0]
def propagate(self) -> List[int]:
    """Propagate the parallel axis of the input through a slice node.

    The sliced input is read from ``self.node.args[0]`` and the slice
    dimension from ``self.node.args[1]``. Optional ``start``, ``stop`` and
    ``step`` are read from positions 2, 3 and 4 when they are present.

    Returns:
        ``[None]`` when ``self.extract_axis`` yields no axis for the input;
        ``[axis]`` when the slice is on a different dimension, or is a
        no-op on the parallel axis (starts at 0, reaches the end, step 1);
        ``[]`` when the slice actually narrows the parallel axis, which is
        not allowed.
    """
    node_args = self.node.args
    arg, slice_dim = node_args[0], node_args[1]
    axis = self.extract_axis(arg)
    if axis is None:
        return [None]

    val = arg.meta["val"]
    ndim = val.ndim
    # Normalize a negative dimension index into the range [0, ndim).
    slice_dim = (slice_dim + ndim) % ndim
    if slice_dim != axis:
        # Slicing another dimension leaves the parallel axis untouched.
        return [axis]

    # Slice on the parallel axis is not allowed, except when it is a nop.
    size = val.shape[axis]
    start, stop, step = 0, size, 1
    # Each optional positional argument is checked independently: a chained
    # if/elif would stop at the first match and never read `stop` or `step`.
    # NOTE(review): presumably this handles aten.slice.Tensor, whose schema
    # allows None for start/end meaning "use the default" - confirm.
    if len(node_args) > 2 and node_args[2] is not None:
        start = node_args[2]
    if len(node_args) > 3 and node_args[3] is not None:
        stop = node_args[3]
    if len(node_args) > 4:
        step = node_args[4]

    if start == 0 and stop >= size and step == 1:
        return [axis]
    return []