def validation_step()

in src/hyperpod_nemo_adapter/collections/model/sagemaker_base_model.py [0:0]


    def validation_step(self, batch, batch_idx):
        """
        Run a single validation step and accumulate its loss.

        If the model config has a truthy ``dpo`` entry, the batch is split into
        prompt / chosen / rejected ids and masks via ``_prepare_dpo_input_batch``
        and scored with ``compute_dpo_loss`` against ``self.ref_model``, using
        ``beta`` (default 0.1) and ``label_smoothing`` (default 0.0) from the
        ``dpo`` config. Otherwise the batch goes through ``_prepare_input_batch``
        and the model's own forward call, and the ``"loss"`` entry of the output
        is used.

        Side effect: the detached loss is added to ``self.val_loss``.

        Returns:
            The validation loss (not detached).
        """
        if not self._cfg.get("dpo", False):
            tokens, _, targets = self._prepare_input_batch(batch, batch_idx)
            # The mask returned by _prepare_input_batch is discarded;
            # attention_mask=None is passed to the forward call instead.
            outputs = self(input_ids=tokens, attention_mask=None, labels=targets)
            loss = outputs["loss"]
        else:
            (
                prompt_ids,
                prompt_mask,
                chosen_ids,
                chosen_mask,
                rejected_ids,
                rejected_mask,
            ) = self._prepare_dpo_input_batch(batch, batch_idx)
            dpo_cfg = self._cfg.dpo
            loss = compute_dpo_loss(
                model=self,
                ref_model=self.ref_model,
                prompt_ids=prompt_ids,
                prompt_mask=prompt_mask,
                chosen_ids=chosen_ids,
                chosen_mask=chosen_mask,
                rejected_ids=rejected_ids,
                rejected_mask=rejected_mask,
                max_length=self._cfg.max_context_width,
                beta=dpo_cfg.get("beta", 0.1),
                label_smoothing=dpo_cfg.get("label_smoothing", 0.0),
                peft=self.use_peft,
                bf16=self._cfg.precision == "bf16",
            )
        # Accumulate a detached copy so the running total does not hold the graph.
        self.val_loss += loss.detach()
        return loss