in src/accelerate/utils/fsdp_utils.py [0:0]
def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
"""Prepares the model for FSDP2 in-place. Also returns the model to avoid misuse of the original model.
Args:
accelerator (`Accelerator`): The accelerator instance
model (`torch.nn.Module`): The model to prepare
Returns:
`torch.nn.Module`: Prepared model
"""
    from torch.distributed.fsdp import FSDPModule, MixedPrecisionPolicy, fully_shard

    is_type_fsdp = isinstance(model, FSDPModule) or (
        is_compiled_module(model) and isinstance(model._orig_mod, FSDPModule)
    )
    if is_type_fsdp:
        return model

    fsdp2_plugin = accelerator.state.fsdp_plugin
    fsdp2_plugin.set_auto_wrap_policy(model)
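    # Keep a handle on the full (unsharded) state_dict now, so that when `cpu_ram_efficient_loading`
    # is enabled the rank-0 weights can be redistributed onto the sharded model via
    # `fsdp2_load_full_state_dict` further below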
    original_sd = model.state_dict()

    fsdp2_kwargs = {
        "reshard_after_forward": fsdp2_plugin.reshard_after_forward,
        "offload_policy": fsdp2_plugin.cpu_offload,
        # `fully_shard` doesn't accept `None` in case of `MixedPrecisionPolicy`
        "mp_policy": fsdp2_plugin.mixed_precision_policy or MixedPrecisionPolicy(),
    }
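    # These kwargs are forwarded to every `fully_shard` call below (child modules and the root model)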
    model_has_params4bit = False
    for name, param in model.named_parameters():
        # this is a temporary fix: models with bnb params cannot be moved from GPU to a meta
        # device with FSDP2, because torch operations don't return the original class type
        # bypassing the move to meta will still cause the VRAM spike, but at least the model still loads
        if param.__class__.__name__ == "Params4bit":
            model_has_params4bit = True
            break
    if fsdp2_plugin.cpu_ram_efficient_loading and not model_has_params4bit:
        # Context: `fully_shard` moves the model to GPU if it was on CPU, however it can also be on `meta` and then it stays there even after `fully_shard`
        # For this reason, we need to move the model to `meta` device, as then sharding happens on `meta` device
        # If we kept the model on CPU (with `cpu_ram_efficient_loading` the model is on CPU on all ranks, though non-main ranks only hold `torch.empty` tensors), `fully_shard` would move it to GPU
        # Afterwards, when we call `fsdp2_load_full_state_dict`, creating the state_dict would briefly leave two copies of the model state_dict on the GPU -> VRAM spike
        # We need to keep the original non-persistent buffers, as those MAY not be in the state_dict, resulting in them staying on meta device
        # Also, these buffers aren't getting sharded by default
        # We get the FQNs of all non-persistent buffers, to re-register them after
        non_persistent_buffer_fqns = get_non_persistent_buffers(model, recurse=True, fqns=True)
        original_non_persistent_buffers = copy.deepcopy(
            {k: v for k, v in model.named_buffers() if k in non_persistent_buffer_fqns}
        )
        # We move the model to meta device, as then sharding happens on meta device
        model = model.to(torch.device("meta"))
        # We need to re-tie the weights, not exactly sure why, but if we don't, references to `lm_head`/`embed_tokens` stay hanging -> more VRAM usage
        # We assume `transformers` models have a `tie_weights` method if they support it
        if hasattr(model, "tie_weights"):
            model.tie_weights()
    auto_wrap_policy_func = fsdp2_prepare_auto_wrap_policy(fsdp2_plugin, model)
    if auto_wrap_policy_func is not None:
        # We skip the model itself, as that one is always wrapped
        for module in get_module_children_bottom_up(model)[:-1]:
            if auto_wrap_policy_func(module) and not isinstance(module, FSDPModule):
                fully_shard(module, **fsdp2_kwargs)
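    # Finally wrap the top-level model itself, so any parameters not captured by the
    # auto-wrap policy end up in the root FSDP module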
    if not isinstance(model, FSDPModule):
        fully_shard(model, **fsdp2_kwargs)

    if fsdp2_plugin.cpu_ram_efficient_loading:
        # If `cpu_ram_efficient_loading` is enabled, only rank 0 loads the weights
        # Other ranks have an empty model on `meta` device, so we need to distribute the weights properly
        fsdp2_load_full_state_dict(accelerator, model, original_sd)
    if fsdp2_plugin.cpu_ram_efficient_loading and not model_has_params4bit:
        # We re-register the buffers, as they may not be in the state_dict
        for fqn, buffer_tensor in original_non_persistent_buffers.items():
            buffer_tensor = buffer_tensor.to(accelerator.device)

            if "." in fqn:
                parent_fqn, local_buffer_name = fqn.rsplit(".", 1)
                parent_module = model.get_submodule(parent_fqn)
            else:
                local_buffer_name = fqn
                parent_module = model

            parent_module.register_buffer(local_buffer_name, buffer_tensor, persistent=False)
        # We need to tie the weights again, as the call to `load_full_state_dict` breaks the tie
        # Needs to be called both here and above:
        # removing this call makes the model produce a slightly different loss
        # removing the call above leads to extra memory usage as explained in the comment above
        if hasattr(model, "tie_weights"):
            model.tie_weights()
    # There is no `dtype` attribute on nn.Module
    # Set it to None if it doesn't exist and do the upcast always
    model_dtype = getattr(model, "dtype", None)
    if accelerator.mixed_precision != "no" and (model_dtype is None or model_dtype != torch.float32):
        # We upcast the model according to `deepspeed`'s implementation
        # More info about this can be found in `accelerator.py:prepare_model`s FSDP1 section
        model = model.to(torch.float32)
        if accelerator.is_main_process:
            # TODO(siro1): Add a warning for each parameter that was upcasted
            warnings.warn(
                "FSDP upcast of low precision parameters to fp32 (since mixed_precision != 'no') may affect the precision of model checkpoints."
            )

    return model
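
# --- Usage sketch (illustrative, not part of the original module) ---
# `fsdp2_prepare_model` is normally reached through `Accelerator.prepare` when the process
# is launched with an FSDP2 configuration (e.g. `accelerate launch` with `fsdp_version: 2`
# in the FSDP section of the config). The snippet below is a minimal sketch under that
# assumption; `MyModel` is a hypothetical `torch.nn.Module`.
#
#   from accelerate import Accelerator
#
#   accelerator = Accelerator(mixed_precision="bf16")
#   model = MyModel()
#   model = accelerator.prepare(model)  # under FSDP2 this calls `fsdp2_prepare_model(accelerator, model)`
#   # `model` is now wrapped with `fully_shard` and ready for the training loop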