def _training_step_dpo()

in src/hyperpod_nemo_adapter/collections/model/sagemaker_base_model.py


    def _training_step_dpo(self, batch, batch_idx, *a, **kw):
        """Run one DPO training step and return the loss.

        Prepares the prompt/chosen/rejected tensors for the batch, assembles
        the loss arguments from config, and computes the DPO loss, wrapping
        the computation in a Transformer Engine FP8 autocast region when the
        SMP model is used with FP8 enabled.
        """
        # Tokenized prompt, chosen, and rejected sequences together with
        # their attention masks for this batch.
        prompt_ids, prompt_mask, chosen_ids, chosen_mask, rejected_ids, rejected_mask = self._prepare_dpo_input_batch(
            batch, batch_idx
        )
        dpo_params = {
            "model": self,  # policy model being trained
            "ref_model": self.ref_model,  # frozen reference model
            "prompt_ids": prompt_ids,
            "prompt_mask": prompt_mask,
            "chosen_ids": chosen_ids,
            "chosen_mask": chosen_mask,
            "rejected_ids": rejected_ids,
            "rejected_mask": rejected_mask,
            "max_length": self._cfg.max_context_width,
            "beta": self._cfg.dpo.get("beta", 0.1),  # strength of the implicit KL penalty
            "label_smoothing": self._cfg.dpo.get("label_smoothing", 0.0),
            "peft": self.use_peft,
            "bf16": self._cfg.precision == "bf16",
        }
        if self.use_smp_model and self._cfg.fp8:
            # With SMP and FP8 enabled, run the loss computation inside an
            # FP8 autocast region scoped to the world process group.
            fp8 = self._cfg.fp8
            fp8_recipe = self.fp8_recipe
            fp8_group = tsm.state.world_process_group
            with transformer_engine.pytorch.fp8_autocast(
                enabled=fp8,
                fp8_recipe=fp8_recipe,
                fp8_group=fp8_group,
            ):
                loss = compute_dpo_loss(
                    *a,
                    **dpo_params,
                    **kw,
                )
        else:
            loss = compute_dpo_loss(
                *a,
                **dpo_params,
                **kw,
            )
        return loss
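
The body of compute_dpo_loss is not shown on this page. For orientation, the sketch below shows the standard DPO objective (Rafailov et al., 2023) that the beta and label_smoothing parameters suggest; it assumes per-sequence log-probabilities for the chosen and rejected completions have already been gathered from the policy (model) and the frozen reference (ref_model). The function name dpo_loss and its signature are illustrative, not the adapter's actual API.

    import torch
    import torch.nn.functional as F

    def dpo_loss(
        policy_chosen_logps: torch.Tensor,    # log pi(y_w | x) under the policy
        policy_rejected_logps: torch.Tensor,  # log pi(y_l | x) under the policy
        ref_chosen_logps: torch.Tensor,       # log pi_ref(y_w | x) under the reference
        ref_rejected_logps: torch.Tensor,     # log pi_ref(y_l | x) under the reference
        beta: float = 0.1,
        label_smoothing: float = 0.0,
    ) -> torch.Tensor:
        # Implicit reward margin: beta * (chosen log-ratio - rejected log-ratio).
        chosen_ratio = policy_chosen_logps - ref_chosen_logps
        rejected_ratio = policy_rejected_logps - ref_rejected_logps
        logits = beta * (chosen_ratio - rejected_ratio)
        # Label-smoothed ("conservative") DPO; with label_smoothing == 0 this
        # reduces to the standard -log(sigmoid(logits)) objective.
        loss = (
            -F.logsigmoid(logits) * (1.0 - label_smoothing)
            - F.logsigmoid(-logits) * label_smoothing
        )
        return loss.mean()

In this formulation, beta scales the implicit KL penalty against the reference model: larger values keep the policy closer to ref_model, while smaller values let it drift further toward the preference signal.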