in optimum/fx/parallelization/op_registry/op_handlers.py [0:0]
def propagate(self) -> List[int]:
    """Propagate the parallel axis through an op with one or two node inputs.

    With a single node input the work is delegated to
    ``UnaryOpParallelAxisPropagateHandler``. With two node inputs, the axis
    reported by ``self.extract_axis`` for each input is re-expressed relative
    to the higher-rank shape (shapes are aligned from their trailing
    dimension), and the two axes are then reconciled.

    Returns:
        ``[axis]`` when at most one input carries an axis or both agree
        after alignment (``axis`` is ``None`` when neither input carries
        one); an empty list when both inputs carry an axis and they differ.

    Raises:
        AssertionError: if there are more than two node inputs, or if the
            two input shapes cannot be broadcast together.
    """
    input_nodes = self.node.all_input_nodes
    # A single node input is handled exactly like a unary op.
    if len(input_nodes) == 1:
        return UnaryOpParallelAxisPropagateHandler(self.node, self.meta_key, self.config).propagate()
    assert len(input_nodes) == 2, "binary op should have exact two nodes as inputs"

    lhs_node, rhs_node = input_nodes
    lhs_shape, rhs_shape = lhs_node.meta["val"].shape, rhs_node.meta["val"].shape
    lhs_axis = self.extract_axis(lhs_node)
    rhs_axis = self.extract_axis(rhs_node)

    # Walk both shapes from the trailing dimension. Two sizes are compatible
    # when they are equal or one of them is 1. (A divisibility test would
    # wrongly accept e.g. 4 vs 2 and would divide by zero on empty dims.)
    for lhs_dim, rhs_dim in zip(reversed(lhs_shape), reversed(rhs_shape)):
        assert (
            lhs_dim == rhs_dim or lhs_dim == 1 or rhs_dim == 1
        ), f"shape {lhs_shape} and {rhs_shape} are not broadcastable!"

    # The lower-rank operand gains leading dimensions in the aligned shape,
    # so its axis index moves right by the rank difference.
    rank_gap = len(lhs_shape) - len(rhs_shape)
    if rank_gap < 0 and lhs_axis is not None:
        lhs_axis -= rank_gap
    elif rank_gap > 0 and rhs_axis is not None:
        rhs_axis += rank_gap

    if lhs_axis is None:
        return [rhs_axis]
    if rhs_axis is None:
        return [lhs_axis]
    if lhs_axis != rhs_axis:
        # Both inputs carry an axis but disagree after alignment.
        return []
    return [lhs_axis]