def run()

in optimum/fx/parallelization/passes.py [0:0]

Weight loading and initialization pass: for every parameter of the traced GraphModule it either reuses a cached parameter (on recompilation), loads the shard owned by the current rank from the safetensors files referenced in ctx.weight_map, or initializes the weight, with rank 0 materializing parallel weights and scattering the shards across the tensor-parallel group. Tied parameters are materialized once and shared, and the resulting parameters are re-registered on their parent submodules before the GraphModule is returned.

    def run(self, graph_module: GraphModule, ctx: ParallelExecutionCtx, config: Config) -> GraphModule:
        world_size = dist.get_world_size(ctx.tp_group)
        tp_rank = dist.get_rank(ctx.tp_group)

        # new_parameters: (qualified name, parameter) pairs to re-register on the module tree;
        # tied_parameters: id of an original parameter -> the parameter all of its aliases should share;
        # param_cache: parameters kept across recompilations
        new_parameters, tied_parameters, param_cache = [], {}, ctx.param_cache
        for name, param in sorted(graph_module.named_parameters(remove_duplicate=False)):
            # skip initializing new params when recompilation happens
            if name in param_cache:
                new_parameters.append((name, param_cache[name]))
                continue

            param_meta: ParameterMeta = getattr(param, "meta")
            # skip already initialized/loaded tied parameters
            if param_meta.is_tied and id(param) in tied_parameters:
                new_parameters.append((name, tied_parameters[id(param)]))
                continue

            # local (sharded) shape: split the parallel dimension evenly across the tensor-parallel ranks
            shape = [
                param.size(dim) // world_size if dim == param_meta.dim and param_meta.is_parallel else param.size(dim)
                for dim in range(param.ndim)
            ]

            # reuse the parameter as-is when it is not sharded and already lives on the target device,
            # otherwise allocate a zero buffer with the local shape on the target device
            if not param_meta.is_parallel and param.device == ctx.current_device:
                new_param = param
            else:
                new_param = nn.Parameter(
                    torch.zeros(*shape, dtype=param.dtype, device=ctx.current_device),
                    requires_grad=param.requires_grad,
                )

            # load weights if possible
            for source, target in sorted(param_meta.mapping.items()):
                if target.source in ctx.weight_map:
                    from safetensors import safe_open

                    with safe_open(ctx.weight_map[target.source], framework="pt", device="cpu") as fp:
                        tensor_slice = fp.get_slice(target.source)
                        # where the loaded slice lands inside the local parameter ...
                        source_index = [
                            source.to_slice() if dim == param_meta.dim else slice(None, None, None)
                            for dim in range(param.ndim)
                        ]
                        # ... and which region of the checkpoint tensor to read for it
                        load_index = [
                            target.index if dim == param_meta.dim else slice(None, None, None)
                            for dim in range(param.ndim)
                        ]

                        # materialize only the required slice, copy it into an owned buffer,
                        # then write it into the parameter in place
                        tensor = tensor_slice[load_index].contiguous()
                        tensor = torch.empty_like(tensor).copy_(tensor)
                        with torch.no_grad():
                            new_param.data[source_index].copy_(tensor)

            # weights initialization
            if param_meta.need_initialize:
                for source, target in sorted(param_meta.mapping.items()):
                    if target.source in ctx.weight_map:
                        continue
                    if not param_meta.is_parallel or tp_rank == 0:
                        # non-parallel parameters are initialized on every rank; parallel ones
                        # only on the master rank (tp_rank 0) and scattered to the other ranks below
                        weight = torch.empty(*target.shape, dtype=param.dtype, device="cpu")
                        init_fn = param_meta.init_fn if param_meta.init_fn else config.weight_init_fn
                        init_fn(weight)
                        weight = weight.to(ctx.current_device)
                    else:
                        weight = None
                    # region of the local parameter that this mapping entry fills
                    index = [
                        source.to_slice() if dim == param_meta.dim else slice(None, None, None)
                        for dim in range(param.ndim)
                    ]
                    with torch.no_grad():
                        if param_meta.is_parallel:
                            # rank 0 provides the full weight; each rank receives its shard along the parallel dim
                            scatter(ctx.tp_group, weight, new_param.data[index], scatter_dim=param_meta.dim)
                        else:
                            new_param.data[index].copy_(weight)
            # carry the parallelization metadata over to the new parameter
            setattr(new_param, "meta", param_meta)

            # record only parameters that were actually replaced, and remember tied parameters
            # so every alias of the original ends up sharing the same new parameter
            if id(new_param) != id(param):
                new_parameters.append((name, new_param))
            if param_meta.is_tied:
                tied_parameters[id(param)] = new_param

        # swap the rebuilt parameters into their parent modules and populate the
        # cache so future recompilations can reuse them
        for name, new_param in new_parameters:
            prefix_and_field = name.rsplit(".", maxsplit=1)
            if len(prefix_and_field) == 2:
                parent_mod = graph_module.get_submodule(prefix_and_field[0])
                field = prefix_and_field[1]
            else:
                parent_mod = graph_module
                field = name
            if name not in param_cache:
                param_cache[name] = new_param
            setattr(parent_mod, field, new_param)

        return graph_module
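
The checkpoint-loading branch relies on safetensors' lazy slicing API (safe_open / get_slice), so each rank reads only its shard of a weight instead of the full tensor. Below is a minimal, self-contained sketch of that pattern; the file name, tensor name and shard arithmetic are made up for illustration and are not part of the pass.

    import torch
    from safetensors import safe_open
    from safetensors.torch import save_file

    # build a toy checkpoint containing a single 8x4 weight
    save_file({"linear.weight": torch.arange(32, dtype=torch.float32).reshape(8, 4)}, "toy.safetensors")

    world_size, rank, dim = 2, 0, 0  # pretend we are rank 0 of a 2-way tensor-parallel group, sharding dim 0
    with safe_open("toy.safetensors", framework="pt", device="cpu") as fp:
        tensor_slice = fp.get_slice("linear.weight")               # no tensor data is read yet
        rows_per_rank = tensor_slice.get_shape()[dim] // world_size
        shard = tensor_slice[rank * rows_per_rank : (rank + 1) * rows_per_rank]  # reads only this rank's rows
    print(shard.shape)  # torch.Size([4, 4])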
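
For parameters that still need initialization, only the master rank materializes the full weight of a parallel parameter; the shards are then distributed along the parallel dimension. The scatter helper called above comes from the parallelization utilities; the sketch below shows the same idea directly with torch.distributed primitives, assuming a single default process group and an evenly divisible dimension, and is an illustration rather than the library's implementation.

    import torch
    import torch.distributed as dist

    def scatter_along_dim(group, full_weight, local_out, dim):
        """Rank 0 splits full_weight along dim and sends one chunk into every rank's local_out."""
        world_size = dist.get_world_size(group)
        if dist.get_rank(group) == 0:
            # assumes full_weight.size(dim) is divisible by world_size so all chunks match local_out
            chunks = [c.contiguous() for c in torch.chunk(full_weight, world_size, dim=dim)]
        else:
            chunks = None  # non-source ranks only receive
        dist.scatter(local_out, scatter_list=chunks, src=0, group=group)
        return local_out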
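
The final loop swaps the rebuilt parameters back into the module tree by their qualified names. The same get_submodule / setattr pattern works on any nn.Module; the standalone example below uses an invented toy model and parameter name.

    import torch
    import torch.nn as nn

    model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))
    name = "2.weight"                        # qualified name, as produced by named_parameters()
    replacement = nn.Parameter(torch.zeros(2, 4))

    prefix, _, field = name.rpartition(".")
    parent = model.get_submodule(prefix) if prefix else model
    setattr(parent, field, replacement)      # re-registers the parameter under the same name

    assert model.get_parameter(name) is replacement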