def _prepare_fsdp2()

in src/accelerate/accelerator.py


    def _prepare_fsdp2(self, *args):
        # First pass: prepare everything except schedulers (and model, which is prepared separately below)
        result = [
            self._prepare_one(obj, first_pass=True) if not isinstance(obj, torch.nn.Module) else obj for obj in args
        ]

        # Second pass: prepare schedulers (they need their optimizers to already be prepared, hence the two passes)
        result = [self._prepare_one(obj) if not isinstance(obj, torch.nn.Module) else obj for obj in result]

        # Prepare the model
        model_index, model = None, None
        for i, obj in enumerate(result):
            if isinstance(obj, torch.nn.Module):
                model_index, model = i, obj
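        # Note: if several models were passed, only the last one found is prepared for FSDP2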

        # Invariant: if we have a model, we also have an optimizer (checked in `prepare`)
        if model_index is None:
            return tuple(result)

        # Needs to be done first, to make sure activation checkpointing (AC) + `fully_shard` will work as expected
        self.state.fsdp_plugin.set_auto_wrap_policy(model)

        # Apply AC if needed
        if self.state.fsdp_plugin.activation_checkpointing:
            model = fsdp2_apply_ac(self, model)

        # Apply compile if needed, has to be *after* applying AC
        # Copied from: `accelerator.prepare_model` ~ L1804
        if self.state.dynamo_plugin.backend != DynamoBackend.NO and not is_compiled_module(model):
            if self.state.dynamo_plugin.use_regional_compilation:
                model = compile_regions(model, **self.state.dynamo_plugin.to_kwargs())
            else:
                model = torch.compile(model, **self.state.dynamo_plugin.to_kwargs())

        # Get the old params and canonicalize their names - we canonicalize so that the old -> new mapping is easy to build
        old_named_params = fsdp2_canonicalize_names(self._get_named_parameters(*tuple(result), drop_refs=True))

        # Swap the optimizer parameters with empty placeholders, so that the `fully_shard` call below does not allocate too much memory
        for obj in result:
            if isinstance(obj, torch.optim.Optimizer):
                for param_group in obj.param_groups:
                    for i, p in enumerate(param_group["params"]):
                        # We drop the reference to the original param here, so that `_move_states_to_device` triggers a reallocation
                        # We copy the original param's data pointer onto the placeholder, so that the mapping to the new params is preserved
                        param_group["params"][i] = torch.empty_like(p)
                        param_group["params"][i].data_ptr = p.data_ptr()
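                        # Note: this assignment shadows the `data_ptr` *method* with a plain int on the placeholder,
                        # which is all the old -> new lookup below needs to resolve it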

        self._models.append(model)

        # Prepare everything FSDP2 related for the model (except AC)
        model = fsdp2_prepare_model(self, model)

        # Remove the old model from the list
        if len(self._models) > 1 and (self._models[-2] is self._models[-1]):
            del self._models[-2]

        # Put the prepared model back into `result` (usually redundant, as the model is modified in place)
        result[model_index] = model

        # Get new params and canonicalize
        new_named_params = fsdp2_canonicalize_names(self._get_named_parameters(*result))
        # Build a map from old to new params
        mapping = {p: new_named_params[n] for n, p in old_named_params.items()}
        # Update the optimizer parameters
        for obj in result:
            if isinstance(obj, torch.optim.Optimizer):
                fsdp2_switch_optimizer_parameters(obj, mapping)

        return result
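
For context, here is a minimal sketch of how user code reaches `_prepare_fsdp2`. It assumes a recent accelerate release in which `FullyShardedDataParallelPlugin(fsdp_version=2)` selects the FSDP2 code path and makes `Accelerator.prepare` dispatch to this method; field names may differ across versions, and the script is meant to be launched with `accelerate launch` on one or more GPUs.

    import torch

    from accelerate import Accelerator
    from accelerate.utils import FullyShardedDataParallelPlugin

    # Assumption: `fsdp_version=2` is the switch that enables FSDP2 (and therefore `_prepare_fsdp2`)
    fsdp_plugin = FullyShardedDataParallelPlugin(fsdp_version=2)
    accelerator = Accelerator(fsdp_plugin=fsdp_plugin)

    model = torch.nn.Sequential(torch.nn.Linear(128, 128), torch.nn.ReLU(), torch.nn.Linear(128, 2))
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

    # `prepare` hands (model, optimizer, scheduler) to `_prepare_fsdp2`: the model is sharded with
    # `fully_shard`, and the optimizer's params are switched to the sharded ones via the old -> new mapping
    model, optimizer, scheduler = accelerator.prepare(model, optimizer, scheduler)

The practical consequence of the parameter-swapping logic above is that objects created before `prepare` (here the optimizer and the scheduler attached to it) keep working afterwards, because their parameter references are rewritten to point at the new sharded params.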