def prepare_model()

in src/accelerate/accelerator.py [0:0]


    def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, evaluation_mode: bool = False):
        """
        Prepares a PyTorch model for training in any distributed setup. It is recommended to use
        [`Accelerator.prepare`] instead.

        The model is registered in `self._models`, optionally wrapped for mixed precision / FP8, moved to
        `self.device` when device placement applies, wrapped for the active distributed type (unless
        `evaluation_mode` is set) and finally compiled when a dynamo backend is configured.

        Args:
            model (`torch.nn.Module`):
                A PyTorch model to prepare. You don't need to prepare a model if it is used only for inference without
                any kind of mixed precision
            device_placement (`bool`, *optional*):
                Whether or not to place the model on the proper device. Will default to `self.device_placement`
                (and to `False` when the distributed type is FSDP).
            evaluation_mode (`bool`, *optional*, defaults to `False`):
                Whether or not to set the model for evaluation only, by just applying mixed precision and
                `torch.compile` (if configured in the `Accelerator` object).

        Returns:
            `torch.nn.Module`: The prepared model. Depending on the configuration this is the input model itself or a
            wrapper around it (e.g. `DistributedDataParallel`, FSDP or a compiled module).

        Raises:
            `ValueError`: If the model was loaded with a multi-device `device_map` while running in a distributed
                mode, if an 8-bit/4-bit model sits on an incompatible device or uses CPU/disk offload, if the model
                holds `DTensor` parameters under DDP, if the TP requirements are not met, or if FSDP2 is in use.
            `NotImplementedError`: If the distributed type is TP and the model has no `tp_size` attribute.

        Example:

        ```python
        >>> from accelerate import Accelerator

        >>> accelerator = Accelerator()
        >>> # Assume a model is defined
        >>> model = accelerator.prepare_model(model)
        ```
        """
        if device_placement is None:
            # FSDP receives `device_id` when wrapping below, so the model is not moved here in that case.
            device_placement = self.device_placement and self.distributed_type != DistributedType.FSDP

        # The model is tracked before validation; the FSDP branch later swaps this entry for the wrapped model.
        self._models.append(model)

        # TODO: Look at enabling native TP training directly with a proper config
        if (
            self.verify_device_map(model)
            and self.distributed_type != DistributedType.NO
            and os.environ.get("ACCELERATE_BYPASS_DEVICE_MAP", "false") != "true"
        ):
            raise ValueError(
                "You can't train a model that has been loaded with `device_map='auto'` in any distributed mode."
                " Please rerun your script specifying `--num_processes=1` or by launching with `python {myscript.py}`."
            )

        if self.native_amp:
            # Keep a handle on the unwrapped forward before it is replaced by the autocast version.
            model._original_forward = model.forward
            autocast_context = get_mixed_precision_context_manager(self.native_amp, self.autocast_handler)
            # NOTE: MS-AMP adds `__func__` already to `model.forward`, so we should always use `model.forward`
            if self.fp8_backend == "MSAMP" or not hasattr(model.forward, "__func__"):
                model_forward_func = model.forward
                model.forward = convert_outputs_to_fp32(autocast_context(model_forward_func))
            else:
                # Bound method: wrap the underlying function and re-bind it to the model at each step.
                model_forward_func = model.forward.__func__
                new_forward = autocast_context(model_forward_func)
                model.forward = MethodType(new_forward, model)
                model.forward = MethodType(convert_outputs_to_fp32(model.forward.__func__), model)

        # We prepare TE after, allowing for bf16 autocast to happen first
        if self.fp8_backend == "TE" and not self.delayed_fp8_autocast:
            model = apply_fp8_autowrap(model, self.te_recipe_handler or self.fp8_recipe_handler)

        if (getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False)) and getattr(
            model, "hf_device_map", False
        ):
            model_devices = set(model.hf_device_map.values())
            if len(model_devices) > 1 and self.distributed_type != DistributedType.NO:
                raise ValueError(
                    "You can't train a model that has been loaded in 8-bit or 4-bit precision on multiple devices in any distributed mode."
                    " In order to use 8-bit or 4-bit models that have been loaded across multiple GPUs the solution is to use Naive Pipeline Parallelism."
                    " Therefore you should not specify that you are under any distributed regime in your accelerate config."
                )
            elif len(model_devices) == 1:
                # Normalize the single mapped device (torch.device, str or raw index) to a device index.
                current_device = list(model_devices)[0]
                if isinstance(current_device, torch.device):
                    current_device_index = current_device.index
                elif isinstance(current_device, str):
                    current_device_index = torch.device(current_device).index
                else:
                    current_device_index = current_device

                if self.device.type == "cpu" and is_bitsandbytes_multi_backend_available():
                    # bnb with multi-backend supports CPU which don't need to check index.
                    pass
                # NOTE(review): `current_device_index` is `None` when the mapped device carries no index
                # (e.g. "cpu"); `torch.device(None)` presumably raises before the offload check below is
                # reached — confirm whether that path needs an explicit guard.
                elif torch.device(current_device_index) != self.device:
                    # if on the first device (GPU 0) we don't care
                    if (self.device.index is not None) or (current_device_index != 0):
                        raise ValueError(
                            "You can't train a model that has been loaded in 8-bit or 4-bit precision on a different device than the one "
                            "you're training on. Make sure you loaded the model on the correct device using for example `device_map={'':torch.cuda.current_device()}` or `device_map={'':torch.xpu.current_device()}`"
                        )
            if (
                ("cpu" in model_devices and not is_bitsandbytes_multi_backend_available())
                or ("cpu" in model_devices and is_xpu_available())
                or "disk" in model_devices
            ):
                raise ValueError(
                    "You can't train a model that has been loaded in 8-bit or 4-bit precision with CPU or disk offload. "
                    "If you want train the 8-bit or 4-bit model in CPU, please install bitsandbytes with multi-backend, see https://huggingface.co/docs/bitsandbytes/main/en/installation#multi-backend"
                )
        elif device_placement and not self.verify_device_map(model):
            model = model.to(self.device)
        if not evaluation_mode:
            if self.distributed_type in (
                DistributedType.MULTI_GPU,
                DistributedType.MULTI_MLU,
                DistributedType.MULTI_SDAA,
                DistributedType.MULTI_MUSA,
                DistributedType.MULTI_NPU,
                DistributedType.MULTI_XPU,
                DistributedType.MULTI_HPU,
            ):
                if model_has_dtensor(model):
                    raise ValueError(
                        "Your model contains `DTensor` parameters, which is incompatible with DDP. Maybe you loaded your model with `device_map='auto'`? Specify `device_map='cuda'` or 'cpu' instead."
                    )
                # Only wrap in DDP when at least one parameter is trainable.
                if any(p.requires_grad for p in model.parameters()):
                    kwargs = self.ddp_handler.to_kwargs() if self.ddp_handler is not None else {}
                    # TODO: Look at enabling native TP training directly with a proper config
                    if os.environ.get("ACCELERATE_BYPASS_DEVICE_MAP", "false") != "true":
                        if self.device.type == "hpu":
                            device_ids, output_device = [self.device.index], self.device.index
                        else:
                            device_ids, output_device = [self.local_process_index], self.local_process_index
                    else:
                        device_ids, output_device = None, None

                    model = torch.nn.parallel.DistributedDataParallel(
                        model, device_ids=device_ids, output_device=output_device, **kwargs
                    )
                    if self.ddp_handler is not None:
                        self.ddp_handler.register_comm_hook(model)
            elif self.distributed_type == DistributedType.TP:
                if not compare_versions("transformers", ">=", BETA_TP_AVAILABLE_TRANSFORMERS_VERSION):
                    raise ValueError(f"TP requires transformers >= {BETA_TP_AVAILABLE_TRANSFORMERS_VERSION}")
                if not hasattr(model, "tp_size"):
                    raise NotImplementedError(
                        "Model should undergo tensor parallel before passing it to accelerate. "
                        "You can use .from_pretrained(..., tp_plan='auto') if the model supports it."
                    )
                if model.tp_size != self.state.torch_tp_plugin.tp_size:
                    raise ValueError(
                        f"tp_size in the plugin {self.state.torch_tp_plugin.tp_size} should be same as model's tp size {model.tp_size}"
                    )
            elif self.is_fsdp2:
                raise ValueError(
                    "FSDP2 preparation should be done via `accelerate.prepare()`, as it requires a model and an optimizer."
                )

            elif self.distributed_type == DistributedType.FSDP:
                # We need to fix the optimizer *before* sharding the model
                from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP

                # Check if the model is already a FSDP model due to `Manual Wrapping` and if so,
                # don't wrap it again
                # In case the model is already compiled using PyTorch 2.0 and the wrapped model in it
                # is a FSDP model, don't wrap it again
                is_type_fsdp = isinstance(model, FSDP) or (
                    is_compiled_module(model) and isinstance(model._orig_mod, FSDP)
                )

                if not is_type_fsdp:
                    self.state.fsdp_plugin.set_auto_wrap_policy(model)
                    fsdp_plugin = self.state.fsdp_plugin

                    # need to ensure that params are re-tied after running
                    # param_init_fn
                    fsdp_plugin.param_init_fn = ensure_weights_retied(
                        fsdp_plugin.param_init_fn,
                        model,
                        self.device,
                    )

                    kwargs = {
                        # We fallback to reshard_after_forward if sharding_strategy is not set.
                        # We prefer sharding_strategy to not break the behavior of the existing code.
                        # Deprecation warning has already been issued in `utils.dataclasses.py`
                        "sharding_strategy": fsdp_plugin.sharding_strategy or fsdp_plugin.reshard_after_forward,
                        "cpu_offload": fsdp_plugin.cpu_offload,
                        "auto_wrap_policy": fsdp_plugin.auto_wrap_policy,
                        "mixed_precision": fsdp_plugin.mixed_precision_policy,
                        "sync_module_states": fsdp_plugin.sync_module_states,
                        "backward_prefetch": fsdp_plugin.backward_prefetch,
                        "forward_prefetch": fsdp_plugin.forward_prefetch,
                        "use_orig_params": fsdp_plugin.use_orig_params,
                        "param_init_fn": fsdp_plugin.param_init_fn,
                        "ignored_modules": fsdp_plugin.ignored_modules,
                        "limit_all_gathers": fsdp_plugin.limit_all_gathers,
                        "device_id": self.device,
                    }
                    model = FSDP(model, **kwargs)
                    if fsdp_plugin.activation_checkpointing:
                        from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
                            CheckpointImpl,
                            apply_activation_checkpointing,
                            checkpoint_wrapper,
                        )

                        apply_activation_checkpointing(
                            model,
                            checkpoint_wrapper_fn=functools.partial(
                                checkpoint_wrapper,
                                checkpoint_impl=CheckpointImpl.NO_REENTRANT,
                            ),
                            auto_wrap_policy=fsdp_plugin.auto_wrap_policy,
                        )

                # In the event the model had been loaded in low precision, but
                # mixed precision had also been activated, then we follow DeepSpeed's
                # strategy to hold the parameters in full precision.
                # - assume that trainer.args.bf16 and trainer.args.fp16 are already checked against
                #   fsdp_plugin.mixed_precision_policy.
                # - NOTE: we do not check the mixed_precision attribute on the FSDP root wrapper.
                #   * this attribute will always set by init_utils.init_core_state so its always not None.
                #   * mixed_precision.param_dtype only regards _fwd_bwd_param_dtype
                #   * if model is loaded in 16bit, and even if mixed_precision.param_dtype is None,
                #     we still want to upcast the flat_param.
                if self.mixed_precision != "no":  # if mixed precision is set
                    upcasted_log = []
                    for module in FSDP.fsdp_modules(model):
                        # Referencing DeepSpeed Zero3
                        # - in Init, params are converted to 16bit while partitioning.
                        # - in accelerator.prepare, deepspeed.initialize is called to:
                        #   * creates the DeepSpeedEngine.
                        #   * since zero_optimization() is True , calls engine._configure_zero_optimizer.
                        #
                        # Inside the DeepSpeed Zero3 optimizer configuration, which initializes
                        # DeepSpeedZeroOptimizer_Stage3, during which:
                        #   * trainable_param_groups are obtained from the attached optimizer
                        #     (already partitioned in 16bit).
                        #   * then _setup_for_real_optimizer -> _create_fp32_partitions
                        #     which performs the fp32 upcasting.

                        # To mimic DeepSpeed's casting in FSDP, we look at the (single) FlatParameter held
                        # within an FSDP wrapper. This FlatParameter will be seen by the optimizer.
                        #  - even though there is a torch.device('meta') guard below, we
                        #    expect _init_utils._init_param_handle_from_module to already
                        #    sync the parameter.

                        if not module._has_params:
                            continue  # skip if FSDP module not managing parameters
                        param = module._flat_param
                        if (
                            param.dtype != torch.float32
                            and param.device != torch.device("meta")
                            and param.requires_grad
                        ):
                            # keep log of names_params that was upcasted
                            # NOTE: resorted to this because warnings.simplefilter("once") is somehow not working
                            name_param_log = (module.module.__class__.__name__, ", ".join(module._flat_param._fqns))
                            if name_param_log not in upcasted_log:
                                upcasted_log.append(name_param_log)

                            # this works because of FSDP's _runtime_utils.lazy_init.
                            # Have to be careful not to call anything before this that
                            # triggers lazy_init (e.g., _is_fsdp_root).
                            param.data = param.data.to(torch.float32)  # upcasting
                            module._handle._orig_param_dtype = torch.float32  # update

                    # report the warnings
                    # some messages can be quite repetitive, especially when reporting about layers that have identical architecture.
                    if self.is_main_process:
                        for name_log, param_log in upcasted_log:
                            warnings.warn(
                                f"Upcasted low precision parameters in {name_log} because mixed precision turned on in FSDP. "
                                f"Affects: {param_log}."
                            )

                        if upcasted_log:
                            warnings.warn(
                                "FSDP upcast of low precision parameters may affect the precision of model checkpoints."
                            )

                # if the previous and current models are same, delete the previous one
                if len(self._models) > 1 and (self._models[-2] is self._models[-1]):
                    del self._models[-2]
                # Track the (possibly FSDP-wrapped) model instead of the one appended at the top.
                self._models[-1] = model
            elif self.distributed_type == DistributedType.MULTI_CPU:
                kwargs = self.ddp_handler.to_kwargs() if self.ddp_handler is not None else {}
                model = torch.nn.parallel.DistributedDataParallel(model, **kwargs)
                if self.ddp_handler is not None:
                    self.ddp_handler.register_comm_hook(model)
            elif self.distributed_type == DistributedType.XLA and self.state.fork_launched:
                model = xmp.MpModelWrapper(model).to(self.device)
        # Now we can apply the FP8 autocast
        if self.fp8_backend == "TE" and self.delayed_fp8_autocast:
            model = apply_fp8_autowrap(model, self.te_recipe_handler or self.fp8_recipe_handler)
        # torch.compile should be called last and only if the model isn't already compiled
        if self.state.dynamo_plugin.backend != DynamoBackend.NO and not is_compiled_module(model):
            if self.state.dynamo_plugin.use_regional_compilation:
                model = compile_regions(model, **self.state.dynamo_plugin.to_kwargs())
            else:
                model = torch.compile(model, **self.state.dynamo_plugin.to_kwargs())
        return model