in optimum/fx/parallelization/passes.py
def run(self, graph_module: GraphModule, ctx: ParallelExecutionCtx, config: Config) -> GraphModule:
    world_size = dist.get_world_size(ctx.tp_group)
    tp_rank = dist.get_rank(ctx.tp_group)

    new_parameters, tied_parameters, param_cache = [], {}, ctx.param_cache
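    # walk every parameter (duplicates kept so tied weights stay visible) and build its local replacement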
    for name, param in sorted(graph_module.named_parameters(remove_duplicate=False)):
        # skip initializing new params when recompilation happens
        if name in param_cache:
            new_parameters.append((name, param_cache[name]))
            continue

        param_meta: ParameterMeta = getattr(param, "meta")
        # skip already initialized/loaded tied parameters
        if param_meta.is_tied and id(param) in tied_parameters:
            new_parameters.append((name, tied_parameters[id(param)]))
            continue
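        # shrink the sharded dimension by the tensor-parallel world size; other dimensions keep their original extent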
        shape = [
            param.size(dim) // world_size if dim == param_meta.dim and param_meta.is_parallel else param.size(dim)
            for dim in range(param.ndim)
        ]
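        # reuse the original tensor when it is not sharded and already on the target device, otherwise allocate a fresh buffer there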
        if not param_meta.is_parallel and param.device == ctx.current_device:
            new_param = param
        else:
            new_param = nn.Parameter(
                torch.zeros(*shape, dtype=param.dtype, device=ctx.current_device),
                requires_grad=param.requires_grad,
            )

        # load weights if possible
        for source, target in sorted(param_meta.mapping.items()):
            if target.source in ctx.weight_map:
                from safetensors import safe_open

                with safe_open(ctx.weight_map[target.source], framework="pt", device="cpu") as fp:
                    tensor_slice = fp.get_slice(target.source)
                    source_index = [
                        source.to_slice() if dim == param_meta.dim else slice(None, None, None)
                        for dim in range(param.ndim)
                    ]
                    load_index = [
                        target.index if dim == param_meta.dim else slice(None, None, None)
                        for dim in range(param.ndim)
                    ]
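                    # copy the loaded checkpoint slice into a freshly allocated tensor before writing it into the parameter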
                    tensor = tensor_slice[load_index].contiguous()
                    tensor = torch.empty_like(tensor).copy_(tensor)
                    with torch.no_grad():
                        new_param.data[source_index].copy_(tensor)

        # weights initialization
        if param_meta.need_initialize:
            for source, target in sorted(param_meta.mapping.items()):
                if target.source in ctx.weight_map:
                    continue
                if not param_meta.is_parallel or tp_rank == 0:
                    # initialize weight on master rank
                    weight = torch.empty(*target.shape, dtype=param.dtype, device="cpu")
                    init_fn = param_meta.init_fn if param_meta.init_fn else config.weight_init_fn
                    init_fn(weight)
                    weight = weight.to(ctx.current_device)
                else:
                    weight = None
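                # slice of the new parameter covered by this (source, target) pair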
                index = [
                    source.to_slice() if dim == param_meta.dim else slice(None, None, None)
                    for dim in range(param.ndim)
                ]
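                # sharded parameters receive their shard via scatter from the master rank; replicated ones are copied directly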
                with torch.no_grad():
                    if param_meta.is_parallel:
                        scatter(ctx.tp_group, weight, new_param.data[index], scatter_dim=param_meta.dim)
                    else:
                        new_param.data[index].copy_(weight)
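        # attach the metadata to the replacement and remember it for the module update and for tied weights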
        setattr(new_param, "meta", param_meta)

        if id(new_param) != id(param):
            new_parameters.append((name, new_param))
        if param_meta.is_tied:
            tied_parameters[id(param)] = new_param
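    # swap the new parameters into their owning submodules and populate the cache for future recompilations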
    for name, new_param in new_parameters:
        prefix_and_field = name.rsplit(".", maxsplit=1)
        if len(prefix_and_field) == 2:
            parent_mod = graph_module.get_submodule(prefix_and_field[0])
            field = prefix_and_field[1]
        else:
            parent_mod = graph_module
            field = name

        if name not in param_cache:
            param_cache[name] = new_param
        setattr(parent_mod, field, new_param)

    return graph_module
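For reference, a minimal sketch of what a `scatter` helper like the one called above could look like; this is an assumption for illustration, not the actual optimum.fx.parallelization implementation, and it presumes the sharded dimension divides evenly across the tensor-parallel group:

    # hypothetical stand-in for `scatter(group, weight, shard, scatter_dim)` used above
    import torch
    import torch.distributed as dist

    def scatter_sketch(group, full_weight, local_shard, scatter_dim):
        world_size = dist.get_world_size(group)
        if dist.get_rank(group) == 0:
            # the master rank splits its full copy into one contiguous chunk per rank
            chunks = [c.contiguous() for c in torch.chunk(full_weight, world_size, dim=scatter_dim)]
        else:
            chunks = None  # non-source ranks only receive their shard
        # every rank fills its preallocated `local_shard` buffer with its chunk;
        # group rank 0 is translated to a global rank for the `src` argument
        dist.scatter(local_shard, scatter_list=chunks, src=dist.get_global_rank(group, 0), group=group)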