in src/accelerate/accelerator.py
def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, evaluation_mode: bool = False):
"""
Prepares a PyTorch model for training in any distributed setup. It is recommended to use
[`Accelerator.prepare`] instead.

Args:
model (`torch.nn.Module`):
A PyTorch model to prepare. You don't need to prepare a model if it is used only for inference without
any kind of mixed precision.
device_placement (`bool`, *optional*):
Whether or not to place the model on the proper device. Will default to `self.device_placement`.
evaluation_mode (`bool`, *optional*, defaults to `False`):
Whether or not to set the model for evaluation only, by just applying mixed precision and
`torch.compile` (if configured in the `Accelerator` object).

Example:
```python
>>> from accelerate import Accelerator
>>> accelerator = Accelerator()
>>> # Assume a model is defined
>>> model = accelerator.prepare_model(model)
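>>> # Optionally prepare the model for inference only; this applies mixed
>>> # precision and `torch.compile` but skips the distributed wrappers
>>> model = accelerator.prepare_model(model, evaluation_mode=True)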
```
"""
if device_placement is None:
device_placement = self.device_placement and self.distributed_type != DistributedType.FSDP
self._models.append(model)
# TODO: Look at enabling native TP training directly with a proper config
if (
self.verify_device_map(model)
and self.distributed_type != DistributedType.NO
and os.environ.get("ACCELERATE_BYPASS_DEVICE_MAP", "false") != "true"
):
raise ValueError(
"You can't train a model that has been loaded with `device_map='auto'` in any distributed mode."
" Please rerun your script specifying `--num_processes=1` or by launching with `python {{myscript.py}}`."
)
if self.native_amp:
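# Keep a handle to the unwrapped forward so the mixed-precision wrapper can be
# stripped later (e.g. when unwrapping the model), then run forward under the
# autocast context and convert any half-precision outputs back to fp32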
model._original_forward = model.forward
autocast_context = get_mixed_precision_context_manager(self.native_amp, self.autocast_handler)
# NOTE: MS-AMP adds `__func__` already to `model.forward`, so we should always use `model.forward`
if self.fp8_backend == "MSAMP" or not hasattr(model.forward, "__func__"):
model_forward_func = model.forward
model.forward = convert_outputs_to_fp32(autocast_context(model_forward_func))
else:
model_forward_func = model.forward.__func__
new_forward = autocast_context(model_forward_func)
model.forward = MethodType(new_forward, model)
model.forward = MethodType(convert_outputs_to_fp32(model.forward.__func__), model)
# We prepare TE after, allowing for bf16 autocast to happen first
if self.fp8_backend == "TE" and not self.delayed_fp8_autocast:
model = apply_fp8_autowrap(model, self.te_recipe_handler or self.fp8_recipe_handler)
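# bitsandbytes 8-bit/4-bit models carry an `hf_device_map`; the checks below reject
# quantized models that span multiple devices, sit on a different device than the one
# we are training on, or are offloaded to CPU/disk (with a narrow exception for the
# multi-backend bitsandbytes CPU build), since those setups cannot be trained here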
if (getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False)) and getattr(
model, "hf_device_map", False
):
model_devices = set(model.hf_device_map.values())
if len(model_devices) > 1 and self.distributed_type != DistributedType.NO:
raise ValueError(
"You can't train a model that has been loaded in 8-bit or 4-bit precision on multiple devices in any distributed mode."
" In order to use 8-bit or 4-bit models that have been loaded across multiple GPUs the solution is to use Naive Pipeline Parallelism."
" Therefore you should not specify that you are under any distributed regime in your accelerate config."
)
elif len(model_devices) == 1:
current_device = list(model_devices)[0]
if isinstance(current_device, torch.device):
current_device_index = current_device.index
elif isinstance(current_device, str):
current_device_index = torch.device(current_device).index
else:
current_device_index = current_device
if self.device.type == "cpu" and is_bitsandbytes_multi_backend_available():
# The multi-backend build of bitsandbytes supports CPU, so there is no device index to check.
pass
elif torch.device(current_device_index) != self.device:
# if the model is on the first device (GPU 0), we don't care
if (self.device.index is not None) or (current_device_index != 0):
raise ValueError(
"You can't train a model that has been loaded in 8-bit or 4-bit precision on a different device than the one "
"you're training on. Make sure you loaded the model on the correct device using for example `device_map={'':torch.cuda.current_device()}` or `device_map={'':torch.xpu.current_device()}`"
)
if (
("cpu" in model_devices and not is_bitsandbytes_multi_backend_available())
or ("cpu" in model_devices and is_xpu_available())
or "disk" in model_devices
):
raise ValueError(
"You can't train a model that has been loaded in 8-bit or 4-bit precision with CPU or disk offload. "
"If you want train the 8-bit or 4-bit model in CPU, please install bitsandbytes with multi-backend, see https://huggingface.co/docs/bitsandbytes/main/en/installation#multi-backend"
)
elif device_placement and not self.verify_device_map(model):
model = model.to(self.device)
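# Distributed wrappers (DDP, TP, FSDP, ...) are only applied when training; in
# evaluation mode the model is left as-is apart from mixed precision and `torch.compile`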
if not evaluation_mode:
if self.distributed_type in (
DistributedType.MULTI_GPU,
DistributedType.MULTI_MLU,
DistributedType.MULTI_SDAA,
DistributedType.MULTI_MUSA,
DistributedType.MULTI_NPU,
DistributedType.MULTI_XPU,
DistributedType.MULTI_HPU,
):
if model_has_dtensor(model):
raise ValueError(
"Your model contains `DTensor` parameters, which is incompatible with DDP. Maybe you loaded your model with `device_map='auto'`? Specify `device_map='cuda'` or 'cpu' instead."
)
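# Only wrap in DDP when at least one parameter requires gradients; a fully frozen
# model has no gradients to synchronize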
if any(p.requires_grad for p in model.parameters()):
kwargs = self.ddp_handler.to_kwargs() if self.ddp_handler is not None else {}
# TODO: Look at enabling native TP training directly with a proper config
if os.environ.get("ACCELERATE_BYPASS_DEVICE_MAP", "false") != "true":
if self.device.type == "hpu":
device_ids, output_device = [self.device.index], self.device.index
else:
device_ids, output_device = [self.local_process_index], self.local_process_index
else:
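# When the device map is bypassed, the module may span several devices (or live on CPU),
# in which case DDP expects device_ids and output_device to be None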
device_ids, output_device = None, None
model = torch.nn.parallel.DistributedDataParallel(
model, device_ids=device_ids, output_device=output_device, **kwargs
)
if self.ddp_handler is not None:
self.ddp_handler.register_comm_hook(model)
elif self.distributed_type == DistributedType.TP:
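# The model must already be tensor-parallel (e.g. loaded with
# `from_pretrained(..., tp_plan="auto")`); we only validate the transformers version
# and that the model's tp_size matches the plugin's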
if not compare_versions("transformers", ">=", BETA_TP_AVAILABLE_TRANSFORMERS_VERSION):
raise ValueError(f"TP requires transformers >= {BETA_TP_AVAILABLE_TRANSFORMERS_VERSION}")
if not hasattr(model, "tp_size"):
raise NotImplementedError(
"Model should undergo tensor parallel before passing it to accelerate."
"You can use .from_pretrained(..., tp_plan='auto') if the model supports"
)
if model.tp_size != self.state.torch_tp_plugin.tp_size:
raise ValueError(
f"tp_size in the plugin {self.state.torch_tp_plugin.tp_size} should be same as model's tp size {model.tp_size}"
)
elif self.is_fsdp2:
raise ValueError(
"FSDP2 preparation should be done via `accelerate.prepare()`, as it requires a model and an optimizer."
)
elif self.distributed_type == DistributedType.FSDP:
# We need to fix the optimizer *before* sharding the model
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
# Check if the model is already a FSDP model due to `Manual Wrapping` and if so,
# don't wrap it again
# In case the model is already compiled using PyTorch 2.0 and the wrapped model in it
# is a FSDP model, don't wrap it again
is_type_fsdp = isinstance(model, FSDP) or (
is_compiled_module(model) and isinstance(model._orig_mod, FSDP)
)
if not is_type_fsdp:
self.state.fsdp_plugin.set_auto_wrap_policy(model)
fsdp_plugin = self.state.fsdp_plugin
# need to ensure that params are re-tied after running
# param_init_fn
fsdp_plugin.param_init_fn = ensure_weights_retied(
fsdp_plugin.param_init_fn,
model,
self.device,
)
kwargs = {
# We fall back to reshard_after_forward if sharding_strategy is not set.
# We prefer sharding_strategy so as not to break the behavior of the existing code.
# A deprecation warning has already been issued in `utils.dataclasses.py`.
"sharding_strategy": fsdp_plugin.sharding_strategy or fsdp_plugin.reshard_after_forward,
"cpu_offload": fsdp_plugin.cpu_offload,
"auto_wrap_policy": fsdp_plugin.auto_wrap_policy,
"mixed_precision": fsdp_plugin.mixed_precision_policy,
"sync_module_states": fsdp_plugin.sync_module_states,
"backward_prefetch": fsdp_plugin.backward_prefetch,
"forward_prefetch": fsdp_plugin.forward_prefetch,
"use_orig_params": fsdp_plugin.use_orig_params,
"param_init_fn": fsdp_plugin.param_init_fn,
"ignored_modules": fsdp_plugin.ignored_modules,
"limit_all_gathers": fsdp_plugin.limit_all_gathers,
"device_id": self.device,
}
model = FSDP(model, **kwargs)
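# Optionally wrap submodules with non-reentrant activation checkpointing, reusing the
# same auto-wrap policy so the checkpointed units line up with the FSDP units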
if fsdp_plugin.activation_checkpointing:
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
CheckpointImpl,
apply_activation_checkpointing,
checkpoint_wrapper,
)
apply_activation_checkpointing(
model,
checkpoint_wrapper_fn=functools.partial(
checkpoint_wrapper,
checkpoint_impl=CheckpointImpl.NO_REENTRANT,
),
auto_wrap_policy=fsdp_plugin.auto_wrap_policy,
)
# In the event the model has been loaded in low precision but
# mixed precision has also been activated, we follow DeepSpeed's
# strategy and hold the parameters in full precision.
# - assume that trainer.args.bf16 and trainer.args.fp16 are already checked against
# fsdp_plugin.mixed_precision_policy.
# - NOTE: we do not check the mixed_precision attribute on the FSDP root wrapper.
# * this attribute is always set by init_utils.init_core_state, so it is never None.
# * mixed_precision.param_dtype only regards _fwd_bwd_param_dtype.
# * if the model is loaded in 16-bit, we still want to upcast the flat_param,
# even if mixed_precision.param_dtype is None.
if self.mixed_precision != "no": # if mixed precision is set
upcasted_log = []
for module in FSDP.fsdp_modules(model):
# Referencing DeepSpeed Zero3
# - in Init, params are converted to 16bit while partitioning.
# - in accelerator.prepare, deepspeed.initialize is called to:
# * create the DeepSpeedEngine.
# * since zero_optimization() is True, call engine._configure_zero_optimizer.
#
# Inside the DeepSpeed Zero3 optimizer configuration, which initializes
# DeepSpeedZeroOptimizer_Stage3, during which:
# * trainable_param_groups are obtained from the attached optimizer
# (already partitioned in 16bit).
# * then _setup_for_real_optimizer -> _create_fp32_partitions
# which performs the fp32 upcasting.
# To mimic DeepSpeed's casting in FSDP, we look at the (single) FlatParameter held
# within an FSDP wrapper. This FlatParameter will be seen by the optimizer.
# - even though there is a torch.device('meta') guard below, we
# expect _init_utils._init_param_handle_from_module to already
# sync the parameter.
if not module._has_params:
continue # skip if FSDP module not managing parameters
param = module._flat_param
if (
param.dtype != torch.float32
and param.device != torch.device("meta")
and param.requires_grad
):
# keep a log of the parameter names that were upcasted
# NOTE: we resort to this because warnings.simplefilter("once") is somehow not working
name_param_log = (module.module.__class__.__name__, ", ".join(module._flat_param._fqns))
if name_param_log not in upcasted_log:
upcasted_log.append(name_param_log)
# this works because of FSDP's _runtime_utils.lazy_init.
# Have to be careful not to call anything before this that
# triggers lazy_init (e.g., _is_fsdp_root).
param.data = param.data.to(torch.float32) # upcasting
module._handle._orig_param_dtype = torch.float32 # update
# report the warnings
# some messages can be quite repetitive, especially when reporting about layers that have identical architectures.
if self.is_main_process:
for name_log, param_log in upcasted_log:
warnings.warn(
f"Upcasted low precision parameters in {name_log} because mixed precision turned on in FSDP. "
f"Affects: {param_log}."
)
if len(upcasted_log) > 0:
warnings.warn(
"FSDP upcast of low precision parameters may affect the precision of model checkpoints."
)
# if the previous and current models are the same, delete the previous one
if len(self._models) > 1 and (self._models[-2] is self._models[-1]):
del self._models[-2]
self._models[-1] = model
elif self.distributed_type == DistributedType.MULTI_CPU:
kwargs = self.ddp_handler.to_kwargs() if self.ddp_handler is not None else {}
model = torch.nn.parallel.DistributedDataParallel(model, **kwargs)
if self.ddp_handler is not None:
self.ddp_handler.register_comm_hook(model)
elif self.distributed_type == DistributedType.XLA and self.state.fork_launched:
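# When processes are forked, MpModelWrapper keeps a single copy of the model in shared
# host memory and `.to(self.device)` moves it onto the XLA device in each process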
model = xmp.MpModelWrapper(model).to(self.device)
# Now we can apply the FP8 autocast
if self.fp8_backend == "TE" and self.delayed_fp8_autocast:
model = apply_fp8_autowrap(model, self.te_recipe_handler or self.fp8_recipe_handler)
# torch.compile should be called last and only if the model isn't already compiled
if self.state.dynamo_plugin.backend != DynamoBackend.NO and not is_compiled_module(model):
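# Regional compilation compiles repeated blocks (e.g. transformer layers) individually
# instead of the full model, which typically reduces cold-start compile time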
if self.state.dynamo_plugin.use_regional_compilation:
model = compile_regions(model, **self.state.dynamo_plugin.to_kwargs())
else:
model = torch.compile(model, **self.state.dynamo_plugin.to_kwargs())
return model