in optimum/habana/sentence_transformers/st_gaudi_trainer.py
def _wrap_model(self, model, training=True, dataloader=None):
    """
    Differs from GaudiTrainer._wrap_model:
    - `allow_unused_input=True` was added to `ht.hpu.ModuleCacher()`
    """
    # train/eval could be run multiple times - if already wrapped, don't re-wrap it again
    if self.accelerator.unwrap_model(model) is not model:
        return model

    # Note: in torch.distributed mode, there's no point in wrapping the model
    # inside a DistributedDataParallel as we'll be under `no_grad` anyway.
    if not training:
        return model

    if self.args.parallel_mode == ParallelMode.DISTRIBUTED and self.args.distribution_strategy == "ddp":
        # Forward the DDP-related training arguments to Accelerate's DDP kwargs handler.
        kwargs = {}
        kwargs["find_unused_parameters"] = self.args.ddp_find_unused_parameters
        if self.args.ddp_find_unused_parameters and self.args.gradient_checkpointing:
            logger.warning(
                "ddp_find_unused_parameters and gradient_checkpointing are both True, which may lead to an error:"
                " https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021"
            )
        kwargs["bucket_cap_mb"] = self.args.ddp_bucket_cap_mb

        if self.args.use_habana:
            kwargs["gradient_as_bucket_view"] = True

        if self.args.ddp_broadcast_buffers is not None:
            kwargs["broadcast_buffers"] = self.args.ddp_broadcast_buffers

        self.accelerator.ddp_handler = DistributedDataParallelKwargs(**kwargs)

    if self.args.use_hpu_graphs_for_training:
        import habana_frameworks.torch as ht

        # Cache the model as an HPU graph; for PEFT models, cache the underlying base model.
        if _is_peft_model(model):
            base_model = model.get_base_model()
            ht.hpu.ModuleCacher()(model=base_model, allow_unused_input=True, inplace=True)
        else:
            ht.hpu.ModuleCacher()(model=model, allow_unused_input=True, inplace=True)

    return model
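
For illustration, below is a minimal, hedged sketch of the caching call this override changes, applied to a standalone module instead of the trainer's model. The `ToyEncoder` class, the tensor shapes, and the unused second forward argument are assumptions made for the example; running it requires a Gaudi device with the `habana_frameworks` stack installed.

import torch
import habana_frameworks.torch as ht
import habana_frameworks.torch.core as htcore


class ToyEncoder(torch.nn.Module):
    # Illustrative stand-in: the second forward argument never feeds the output,
    # the kind of situation the `allow_unused_input=True` flag appears intended to tolerate.
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(16, 16)

    def forward(self, token_embeddings, attention_mask):
        return self.proj(token_embeddings)


model = ToyEncoder().to("hpu")

# Same call as in `_wrap_model`, just on the toy module instead of the (base) model.
ht.hpu.ModuleCacher()(model=model, allow_unused_input=True, inplace=True)

out = model(torch.randn(8, 16, device="hpu"), torch.ones(8, 16, device="hpu"))
out.sum().backward()
htcore.mark_step()  # trigger execution in lazy mode

The flag's name suggests it relaxes the requirement that every module input participate in the backward pass; the docstring above only records that it was added relative to GaudiTrainer._wrap_model.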