def run()

in optimum/fx/parallelization/passes.py [0:0]


    def run(self, graph_module: GraphModule, ctx: ParallelExecutionCtx, config: Config) -> GraphModule:
        """
        Place a parallelization marker on every linear, embedding and cross-entropy node.

        For each node in `graph_module.graph.nodes`, the `"parallel_axis"` field stored for the
        node's first argument (the input axis) and, for linear/embedding nodes, for the node
        itself (the output axis) is looked up through
        `ParallelAxisSolverPass.get_stored_field_info`. A marker dict is derived from those
        values and handed to `self.place_marker_per_node`.

        Args:
            graph_module (`GraphModule`):
                Module whose graph nodes are inspected and annotated.
            ctx (`ParallelExecutionCtx`):
                Not read by this method.
            config (`Config`):
                Only `config.enable_sequence_parallel` is read; it must be truthy whenever
                an axis value of 1 is encountered on a path that turns sequence parallelism on.

        Returns:
            `GraphModule`: the same `graph_module` object that was passed in.

        Raises:
            AssertionError: if an axis of 1 is found while `config.enable_sequence_parallel`
                is falsy, or if an embedding node has an input axis other than `None` or an
                output axis other than 1 / `None`.
        """
        seq_parallel_off_msg = "illegal parallel axis for sequence parallelism deactivated setting"

        for node in graph_module.graph.nodes:
            if is_linear(node):
                in_axis = ParallelAxisSolverPass.get_stored_field_info(node.args[0], "parallel_axis")
                out_axis = ParallelAxisSolverPass.get_stored_field_info(node, "parallel_axis")

                if in_axis is None:
                    # No axis recorded on the input: column marker, gathering the output
                    # only when no axis is recorded on the output either.
                    marker = {"axis": "column", "gather_output": out_axis is None}
                elif in_axis == 1:
                    assert config.enable_sequence_parallel, seq_parallel_off_msg
                    marker = {
                        "axis": "column",
                        "sequence_parallel": True,
                        "gather_output": out_axis is None,
                    }
                elif in_axis == 2:
                    use_seq_parallel = bool(out_axis == 1)
                    if use_seq_parallel:
                        assert config.enable_sequence_parallel, seq_parallel_off_msg
                    marker = {
                        "axis": "row",
                        "input_is_parallel": True,
                        "sequence_parallel": use_seq_parallel,
                    }
                else:
                    # Any other input axis value still gets a marker, but an empty one.
                    marker = {}

                self.place_marker_per_node(node, marker)
                continue

            if is_embedding(node):
                in_axis = ParallelAxisSolverPass.get_stored_field_info(node.args[0], "parallel_axis")
                out_axis = ParallelAxisSolverPass.get_stored_field_info(node, "parallel_axis")
                assert in_axis is None and out_axis in [1, None]

                use_seq_parallel = bool(out_axis == 1)
                if use_seq_parallel:
                    assert config.enable_sequence_parallel, seq_parallel_off_msg
                self.place_marker_per_node(node, {"axis": "vocab", "sequence_parallel": use_seq_parallel})
                continue

            if is_cross_entropy(node):
                in_axis = ParallelAxisSolverPass.get_stored_field_info(node.args[0], "parallel_axis")
                # A cross-entropy node is marked only when its input has a recorded axis.
                if in_axis is not None:
                    self.place_marker_per_node(node, {"axis": "vocab"})

        return graph_module