def _wrap_model()

in optimum/habana/sentence_transformers/st_gaudi_trainer.py


    def _wrap_model(self, model, training=True, dataloader=None):
        """
        Differs from GaudiTrainer._wrap_model:
        - `allow_unused_input=True` was added to `ht.hpu.ModuleCacher()`
        """
        # train/eval could be run multiple times - if the model is already wrapped, don't re-wrap it
        if self.accelerator.unwrap_model(model) is not model:
            return model

        # Note: in torch.distributed mode, there's no point in wrapping the model
        # inside a DistributedDataParallel as we'll be under `no_grad` anyway.
        if not training:
            return model

        if self.args.parallel_mode == ParallelMode.DISTRIBUTED and self.args.distribution_strategy == "ddp":
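            # Collect torch DistributedDataParallel options; they are handed to Accelerate
            # below through a DistributedDataParallelKwargs handler.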
            kwargs = {}

            kwargs["find_unused_parameters"] = self.args.ddp_find_unused_parameters
            if self.args.ddp_find_unused_parameters and self.args.gradient_checkpointing:
                logger.warning(
                    "ddp_find_unused_parameters and gradient_checkpointing are both True, which may lead to an error:"
                    " https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021"
                )
            kwargs["bucket_cap_mb"] = self.args.ddp_bucket_cap_mb

            if self.args.use_habana:
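                # gradient_as_bucket_view lets DDP gradients be views into its communication
                # buckets, avoiding an extra gradient copy.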
                kwargs["gradient_as_bucket_view"] = True

            if self.args.ddp_broadcast_buffers is not None:
                kwargs["broadcast_buffers"] = self.args.ddp_broadcast_buffers

            self.accelerator.ddp_handler = DistributedDataParallelKwargs(**kwargs)

        if self.args.use_hpu_graphs_for_training:
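            # Wrap the (base) model with ModuleCacher so its forward pass is captured as an
            # HPU graph and replayed on later calls, reducing host-side overhead.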
            import habana_frameworks.torch as ht

            if _is_peft_model(model):
                base_model = model.get_base_model()
                ht.hpu.ModuleCacher()(model=base_model, allow_unused_input=True, inplace=True)
            else:
                ht.hpu.ModuleCacher()(model=model, allow_unused_input=True, inplace=True)

        return model
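
Usage sketch: the snippet below is a minimal, standalone illustration of the `ht.hpu.ModuleCacher()` call that this override changes. It assumes a Gaudi/HPU machine with `habana_frameworks` and `sentence-transformers` installed; the checkpoint name is only a placeholder.

    import habana_frameworks.torch as ht
    from sentence_transformers import SentenceTransformer

    # Illustrative model; any SentenceTransformer moved to the HPU device works the same way.
    model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2").to("hpu")

    # Same call as in _wrap_model above: cache the module's forward pass as an HPU graph.
    # allow_unused_input=True is the flag this trainer adds on top of GaudiTrainer._wrap_model
    # (see the docstring), presumably so caching tolerates forward inputs that do not
    # contribute to the captured graph.
    ht.hpu.ModuleCacher()(model=model, allow_unused_input=True, inplace=True)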