in src/accelerate/utils/fsdp_utils.py
def fsdp2_load_full_state_dict(accelerator, model: torch.nn.Module, full_sd: dict):
    """
    Loads the full state dict (which may be present only on rank 0) into the sharded model. This is done by
    broadcasting the parameters from rank 0 to all other ranks. This function modifies the model in-place.

    Args:
        accelerator (`Accelerator`): The accelerator instance
        model (`torch.nn.Module`):
            The model to load the state dict into, expected to be on the meta device, otherwise a VRAM spike can occur
        full_sd (`dict`): The full state dict to load, only needs to be present on rank 0
    """
    import torch.distributed as dist
    from torch.distributed.tensor import distribute_tensor

    # Model was previously copied to the meta device
    meta_sharded_sd = model.state_dict()
    sharded_sd = {}

    # Helpers that decide, per parameter, which dtype to cast to and whether the shard must be made contiguous
    def _infer_parameter_dtype(model, param_name, empty_param):
        try:
            old_param = model.get_parameter_or_buffer(param_name)
        except AttributeError:
            # Needed for LoRA, where some entries in the state dict are not registered as plain parameters
            base_param_name, local_param_name = param_name.rsplit(".", 1)
            submodule = model.get_submodule(base_param_name)
            old_param = getattr(submodule, local_param_name)

        is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn")
        casting_dtype = None
        is_param_float8_e4m3fn = is_torch_e4m3fn_available and empty_param.dtype == torch.float8_e4m3fn

        if empty_param.dtype.is_floating_point and not is_param_float8_e4m3fn:
            casting_dtype = old_param.dtype

        return old_param is not None and old_param.is_contiguous(), casting_dtype
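
    # Example: a bf16 model parameter fed an fp32 checkpoint tensor yields (True, torch.bfloat16), so the
    # incoming shard is cast back to the model's dtype; float8_e4m3fn weights are left uncast.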
    def _cast_and_contiguous(tensor, to_contiguous, dtype):
        if dtype is not None:
            tensor = tensor.to(dtype=dtype)
        if to_contiguous:
            tensor = tensor.contiguous()
        return tensor
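
    # Rank 0 distributes the full state dict to the other ranks. Every rank walks the parameters in the same
    # order: rank 0 sends each full tensor, the other ranks receive it into an empty buffer, and each rank
    # then keeps only its local shard via `distribute_tensor`.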
    if accelerator.is_main_process:
        for (param_name, full_param), sharded_param in zip(full_sd.items(), meta_sharded_sd.values()):
            device_mesh = sharded_param.device_mesh
            full_param = full_param.detach().to(device_mesh.device_type)
            dist.broadcast(full_param, src=0, group=device_mesh.get_group())
            sharded_tensor = distribute_tensor(full_param, device_mesh, sharded_param.placements)
            to_contiguous, casting_dtype = _infer_parameter_dtype(
                model,
                param_name,
                full_param,
            )
            sharded_tensor = _cast_and_contiguous(sharded_tensor, to_contiguous, casting_dtype)
            sharded_sd[param_name] = sharded_tensor
    # We need this `else` so that every rank issues a matching `broadcast`, otherwise we deadlock
    else:
        for param_name, sharded_param in meta_sharded_sd.items():
            device_mesh = sharded_param.device_mesh
            full_tensor = torch.empty(sharded_param.size(), device=device_mesh.device_type, dtype=sharded_param.dtype)
            dist.broadcast(full_tensor, src=0, group=device_mesh.get_group())
            sharded_tensor = distribute_tensor(full_tensor, device_mesh, sharded_param.placements)
            to_contiguous, casting_dtype = _infer_parameter_dtype(
                model,
                param_name,
                full_tensor,
            )
            sharded_tensor = _cast_and_contiguous(sharded_tensor, to_contiguous, casting_dtype)
            sharded_sd[param_name] = sharded_tensor

    # We set `assign=True` because our params are on the meta device
    model.load_state_dict(sharded_sd, assign=True)
    return model
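
For context, a minimal usage sketch that drives the function directly. Everything here is an assumption, not
the library's own example: a `torchrun` launch with at least two processes, PyTorch >= 2.6 (where `fully_shard`
is exported from `torch.distributed.fsdp`), and a hypothetical `MyModel` plus checkpoint path. In normal use,
Accelerate's FSDP2 preparation path calls this function internally.

    import torch
    from torch.distributed.fsdp import fully_shard
    from accelerate import Accelerator
    from accelerate.utils.fsdp_utils import fsdp2_load_full_state_dict

    accelerator = Accelerator()

    # Materialize the full checkpoint only on rank 0; the other ranks pass None,
    # which is fine because they never read `full_sd`.
    full_sd = torch.load("ckpt.pt", map_location="cpu") if accelerator.is_main_process else None

    # Build the model on the meta device, then shard it with FSDP2 so that
    # model.state_dict() holds DTensors carrying a device_mesh and placements.
    with torch.device("meta"):
        model = MyModel()  # hypothetical module
    fully_shard(model)

    model = fsdp2_load_full_state_dict(accelerator, model, full_sd)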