in src/accelerate/accelerator.py [0:0]
def _prepare_fsdp2(self, *args):
    # First pass: prepare everything except schedulers (and model, which is prepared separately below)
    result = [
        self._prepare_one(obj, first_pass=True) if not isinstance(obj, torch.nn.Module) else obj for obj in args
    ]
    # Second pass: prepare schedulers
    result = [self._prepare_one(obj) if not isinstance(obj, torch.nn.Module) else obj for obj in result]
    # Prepare the model
    model_index, model = None, None
    for i, obj in enumerate(result):
        if isinstance(obj, torch.nn.Module):
            model_index, model = i, obj
    # Invariant: if we have a model, we also have an optimizer (checked in `prepare`)
    if model_index is None:
        return tuple(result)
    # Needs to be done first, to make sure AC + fully_shard will work as expected
    self.state.fsdp_plugin.set_auto_wrap_policy(model)
    # Apply AC if needed
    if self.state.fsdp_plugin.activation_checkpointing:
        model = fsdp2_apply_ac(self, model)
    # Apply compile if needed, has to be *after* applying AC
    # Copied from: `accelerator.prepare_model` ~ L1804
    if self.state.dynamo_plugin.backend != DynamoBackend.NO and not is_compiled_module(model):
        if self.state.dynamo_plugin.use_regional_compilation:
            model = compile_regions(model, **self.state.dynamo_plugin.to_kwargs())
        else:
            model = torch.compile(model, **self.state.dynamo_plugin.to_kwargs())
    # Get old params and canonicalize - we canonicalize to make the mapping easy
    old_named_params = fsdp2_canonicalize_names(self._get_named_parameters(*tuple(result), drop_refs=True))
    # Swap the optimizer parameters with empty placeholders, so `fully_shard` below will not allocate too much memory
    for obj in result:
        if isinstance(obj, torch.optim.Optimizer):
            for param_group in obj.param_groups:
                for i, p in enumerate(param_group["params"]):
                    # We drop a reference to the original param here, so that _move_states_to_device triggers a reallocation
                    # We reassign the placeholder's data_ptr to the original param's, so that the mapping to the new params is preserved
                    param_group["params"][i] = torch.empty_like(p)
                    param_group["params"][i].data_ptr = p.data_ptr()
    self._models.append(model)
    # Prepare everything FSDP2 related for the model (except AC)
    model = fsdp2_prepare_model(self, model)
    # Remove the old model from the list
    if len(self._models) > 1 and (self._models[-2] is self._models[-1]):
        del self._models[-2]
    # Replace the old model with the new one (shouldn't be needed as everything should be in place)
    result[model_index] = model
    # Get new params and canonicalize
    new_named_params = fsdp2_canonicalize_names(self._get_named_parameters(*result))
    # Build a map from old to new params
    mapping = {p: new_named_params[n] for n, p in old_named_params.items()}
    # Update the optimizer parameters
    for obj in result:
        if isinstance(obj, torch.optim.Optimizer):
            fsdp2_switch_optimizer_parameters(obj, mapping)
    return result
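
The placeholder swap in the optimizer loop above leans on the fact that an instance attribute can shadow a tensor's `data_ptr()` method, so the placeholder keeps carrying the original parameter's pointer value after the optimizer drops its reference to the real tensor. A standalone sketch of that trick (the names here are illustrative, not from the accelerate codebase), showing how a pointer-keyed mapping like the one built at the end of the method can still be resolved:

import torch

param = torch.nn.Parameter(torch.randn(4, 4))
old_ptr = param.data_ptr()  # integer address of the original storage

# Stand-in tensor that does not hold onto the original storage; in `_prepare_fsdp2`
# it replaces the real parameter inside `optimizer.param_groups`.
placeholder = torch.empty_like(param)

# Instance attribute assignment shadows the `data_ptr()` method on this one tensor,
# so the placeholder remembers where the original parameter lived.
placeholder.data_ptr = old_ptr

# An old-pointer -> new-parameter map can then be resolved through the placeholder,
# analogous to the `mapping` handed to `fsdp2_switch_optimizer_parameters`.
new_param = torch.nn.Parameter(torch.randn(4, 4))  # stand-in for the sharded parameter
mapping = {old_ptr: new_param}
assert mapping[placeholder.data_ptr] is new_param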
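
For context, `_prepare_fsdp2` is not called directly; `Accelerator.prepare` dispatches to it when the accelerator is configured for FSDP2. Below is a minimal sketch of that entry point. The `fsdp_version=2` plugin argument and the `TinyModel` module are illustrative assumptions rather than something taken from this file, and the script is meant to run under `accelerate launch` (or another distributed launcher) with a PyTorch build that provides `fully_shard`.

import torch
from accelerate import Accelerator
from accelerate.utils import FullyShardedDataParallelPlugin

# Illustrative toy module (not from the accelerate codebase)
class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.net = torch.nn.Sequential(torch.nn.Linear(16, 32), torch.nn.ReLU(), torch.nn.Linear(32, 4))

    def forward(self, x):
        return self.net(x)

# Assumption: fsdp_version=2 selects the FSDP2 (`fully_shard`) path shown above
fsdp_plugin = FullyShardedDataParallelPlugin(fsdp_version=2)
accelerator = Accelerator(fsdp_plugin=fsdp_plugin)

model = TinyModel()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# `prepare` ends up in `_prepare_fsdp2`, which shards the model and re-points
# the optimizer at the newly sharded parameters
model, optimizer = accelerator.prepare(model, optimizer)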