in optimum/fx/parallelization/passes.py [0:0]
def run(self, graph_module: GraphModule, ctx: ParallelExecutionCtx, config: Config) -> GraphModule:
    """Annotate linear, embedding and cross-entropy nodes with a parallel-layer description.

    For every node recognised by ``is_linear`` / ``is_embedding`` / ``is_cross_entropy``, the
    ``"parallel_axis"`` values stored by ``ParallelAxisSolverPass`` on the node's first input
    (``node.args[0]``) and on the node itself are compared, and a layer description dict
    (keys such as ``"axis"``, ``"gather_output"``, ``"input_is_parallel"``,
    ``"sequence_parallel"``) is attached to the node via ``self.place_marker_per_node``.

    Args:
        graph_module: traced module whose graph nodes are inspected and annotated.
        ctx: parallel execution context (not read by this pass).
        config: parallelization config; only ``enable_sequence_parallel`` is read.

    Returns:
        The same ``graph_module`` object that was passed in.

    Raises:
        AssertionError: if a node carries a sequence-parallel layout (axis 1) while
            ``config.enable_sequence_parallel`` is disabled, or if an embedding node has an
            input/output layout other than (None, 1) or (None, None).
    """
    sp_error = "illegal parallel axis for sequence parallelism deactivated setting"

    def _require_sequence_parallel() -> None:
        # Raise explicitly instead of using `assert`, so the check is not stripped under `python -O`.
        if not config.enable_sequence_parallel:
            raise AssertionError(sp_error)

    for node in graph_module.graph.nodes:
        if is_linear(node):
            axis_before = ParallelAxisSolverPass.get_stored_field_info(node.args[0], "parallel_axis")
            axis_after = ParallelAxisSolverPass.get_stored_field_info(node, "parallel_axis")
            info = {}
            if axis_before is None:
                # Input is not sharded: column-parallel layer; gather unless the output stays sharded.
                info["axis"] = "column"
                info["gather_output"] = axis_after is None
            elif axis_before == 1:
                # Input sharded on axis 1 is only legal with sequence parallelism enabled.
                _require_sequence_parallel()
                info["axis"] = "column"
                info["sequence_parallel"] = True
                info["gather_output"] = axis_after is None
            elif axis_before == 2:
                # Input sharded on axis 2: row-parallel layer consuming an already-parallel input.
                info["axis"] = "row"
                info["input_is_parallel"] = True
                if axis_after == 1:
                    _require_sequence_parallel()
                info["sequence_parallel"] = axis_after == 1
            # NOTE(review): any other axis_before leaves `info` empty, and it is still placed
            # on the node - confirm downstream passes treat an empty marker as "leave as is".
            self.place_marker_per_node(node, info)
        elif is_embedding(node):
            axis_before = ParallelAxisSolverPass.get_stored_field_info(node.args[0], "parallel_axis")
            axis_after = ParallelAxisSolverPass.get_stored_field_info(node, "parallel_axis")
            if not (axis_before is None and axis_after in (1, None)):
                raise AssertionError(
                    "unexpected parallel axis for embedding: "
                    f"axis_before={axis_before!r}, axis_after={axis_after!r}"
                )
            info = {"axis": "vocab"}
            if axis_after == 1:
                _require_sequence_parallel()
            info["sequence_parallel"] = axis_after == 1
            self.place_marker_per_node(node, info)
        elif is_cross_entropy(node):
            axis_before = ParallelAxisSolverPass.get_stored_field_info(node.args[0], "parallel_axis")
            # Only mark the loss when its input (the logits) is sharded on some axis.
            if axis_before is not None:
                self.place_marker_per_node(node, {"axis": "vocab"})
    return graph_module