in src/hyperpod_nemo_adapter/collections/model/sagemaker_base_model.py [0:0]
def _training_step_dpo(self, batch, batch_idx, *a, **kw):
    """Run one DPO (Direct Preference Optimization) training step.

    Decomposes the batch into paired (prompt, chosen, rejected) token ids
    and attention masks, then computes the DPO loss against the frozen
    reference model. When SMP fp8 training is enabled, the loss computation
    runs inside a Transformer Engine fp8 autocast region.

    Args:
        batch: Raw input batch; shape/schema is defined by
            ``_prepare_dpo_input_batch`` — see that helper for details.
        batch_idx: Index of the batch within the current epoch.
        *a: Extra positional args forwarded to ``compute_dpo_loss``.
        **kw: Extra keyword args forwarded to ``compute_dpo_loss``.

    Returns:
        The loss value returned by ``compute_dpo_loss``.
    """
    # Local import keeps the module-level import block untouched.
    from contextlib import nullcontext

    (
        prompt_ids,
        prompt_mask,
        chosen_ids,
        chosen_mask,
        rejected_ids,
        rejected_mask,
    ) = self._prepare_dpo_input_batch(batch, batch_idx)

    dpo_params = {
        "model": self,
        "ref_model": self.ref_model,
        "prompt_ids": prompt_ids,
        "prompt_mask": prompt_mask,
        "chosen_ids": chosen_ids,
        "chosen_mask": chosen_mask,
        "rejected_ids": rejected_ids,
        "rejected_mask": rejected_mask,
        "max_length": self._cfg.max_context_width,
        # Defaults (beta=0.1, no label smoothing) apply when the config
        # omits these keys — NOTE(review): confirm against config schema.
        "beta": self._cfg.dpo.get("beta", 0.1),
        "label_smoothing": self._cfg.dpo.get("label_smoothing", 0.0),
        "peft": self.use_peft,
        "bf16": self._cfg.precision == "bf16",
    }

    # Pick the execution context once, then make a single call to
    # compute_dpo_loss — the original duplicated the call in both branches,
    # which invited the two copies drifting apart over time.
    if self.use_smp_model and self._cfg.fp8:
        ctx = transformer_engine.pytorch.fp8_autocast(
            enabled=self._cfg.fp8,
            fp8_recipe=self.fp8_recipe,
            fp8_group=tsm.state.world_process_group,
        )
    else:
        ctx = nullcontext()

    with ctx:
        loss = compute_dpo_loss(
            *a,
            **dpo_params,
            **kw,
        )
    return loss