in optimum/habana/transformers/trainer.py
def _wrap_model(self, model, training=True, dataloader=None):
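    """
    Prepare `model` for training on Gaudi/HPU. With the "ddp" distribution strategy, this only registers
    DDP kwargs on the accelerator (the actual `DistributedDataParallel` wrapping happens later in
    `accelerator.prepare()`); if HPU graphs are enabled for training, the model is wrapped with ModuleCacher.
    """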
    # train/eval could be run multiple times - if already wrapped, don't re-wrap it again
    if self.accelerator.unwrap_model(model, keep_torch_compile=False) is not model:
        return model

    # Note: in torch.distributed mode, there's no point in wrapping the model
    # inside a DistributedDataParallel as we'll be under `no_grad` anyways.
    if not training:
        return model

    if self.args.parallel_mode == ParallelMode.DISTRIBUTED and self.args.distribution_strategy == "ddp":
        kwargs = {}

        if self.args.ddp_find_unused_parameters is not None:
            kwargs["find_unused_parameters"] = self.args.ddp_find_unused_parameters
            if self.args.ddp_find_unused_parameters and self.args.gradient_checkpointing:
                logger.warning(
                    "ddp_find_unused_parameters and gradient_checkpointing are both True, which may lead to an error:"
                    " https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021"
                )
        elif isinstance(model, PreTrainedModel):
            # find_unused_parameters breaks checkpointing as per
            # https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021
            kwargs["find_unused_parameters"] = not model.is_gradient_checkpointing
        else:
            kwargs["find_unused_parameters"] = True

        if self.args.ddp_bucket_cap_mb is not None:
            kwargs["bucket_cap_mb"] = self.args.ddp_bucket_cap_mb

        if self.args.use_habana:
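            # Gradients become views into DDP's allreduce buckets, which avoids one extra
            # gradient copy and lowers peak memory usage.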
kwargs["gradient_as_bucket_view"] = True

        if self.args.ddp_broadcast_buffers is not None:
            kwargs["broadcast_buffers"] = self.args.ddp_broadcast_buffers
        self.accelerator.ddp_handler = DistributedDataParallelKwargs(**kwargs)

    if self.args.use_hpu_graphs_for_training:
        import habana_frameworks.torch as ht
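
        # ModuleCacher captures the module's forward pass as HPU graphs (cached per input shape)
        # and replays them on subsequent steps, reducing host-side launch overhead.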
        ht.hpu.ModuleCacher()(model=model, inplace=True)

    return model
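
For context, a minimal sketch of how the flags read by `_wrap_model` are typically supplied through `GaudiTrainingArguments`. This is illustrative only (not part of trainer.py); the argument names follow optimum-habana / transformers as assumed here and the exact set may vary across versions.

from optimum.habana import GaudiTrainingArguments

training_args = GaudiTrainingArguments(
    output_dir="./out",
    use_habana=True,                   # HPU run: also enables gradient_as_bucket_view above
    use_lazy_mode=True,
    use_hpu_graphs_for_training=True,  # triggers the ModuleCacher wrapping above
    distribution_strategy="ddp",       # selects the DDP branch above
    ddp_find_unused_parameters=False,  # forwarded to DistributedDataParallelKwargs
    ddp_bucket_cap_mb=64,
    gradient_checkpointing=True,
)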