in src/hyperpod_nemo_adapter/collections/model/sagemaker_base_model.py [0:0]
def validation_step(self, batch, batch_idx):
    """
    Run one validation step and accumulate its loss.

    Two modes, selected by ``self._cfg.get("dpo", False)``:

    * DPO disabled: the batch is unpacked by ``self._prepare_input_batch``.
      Its second element is ignored and ``attention_mask=None`` is passed to
      the forward call. The ``"loss"`` entry of the model output is used.
    * DPO enabled: the batch is unpacked by ``self._prepare_dpo_input_batch``
      into prompt/chosen/rejected ids and masks. These are passed, together
      with ``self.ref_model``, to ``compute_dpo_loss``. ``beta`` and
      ``label_smoothing`` are read from ``self._cfg.dpo``, falling back to
      0.1 and 0.0 respectively.

    Args:
        batch: the raw validation batch, handed unchanged to the batch-prep helper.
        batch_idx: index of the batch, also forwarded to the batch-prep helper.

    Returns:
        The (non-detached) loss for this batch. A detached copy is added to
        ``self.val_loss``, presumably for epoch-level averaging elsewhere
        (TODO confirm where it is reset/consumed).
    """
    if not self._cfg.get("dpo", False):
        input_ids, _, labels = self._prepare_input_batch(batch, batch_idx)
        outputs = self(input_ids=input_ids, attention_mask=None, labels=labels)
        loss = outputs["loss"]
    else:
        (
            prompt_ids,
            prompt_mask,
            chosen_ids,
            chosen_mask,
            rejected_ids,
            rejected_mask,
        ) = self._prepare_dpo_input_batch(batch, batch_idx)
        loss = compute_dpo_loss(
            model=self,
            ref_model=self.ref_model,
            prompt_ids=prompt_ids,
            prompt_mask=prompt_mask,
            chosen_ids=chosen_ids,
            chosen_mask=chosen_mask,
            rejected_ids=rejected_ids,
            rejected_mask=rejected_mask,
            max_length=self._cfg.max_context_width,
            beta=self._cfg.dpo.get("beta", 0.1),
            label_smoothing=self._cfg.dpo.get("label_smoothing", 0.0),
            peft=self.use_peft,
            bf16=self._cfg.precision == "bf16",
        )

    # Accumulate a graph-free copy so the running total does not retain autograd history.
    self.val_loss += loss.detach()
    return loss