in src/hyperpod_nemo_adapter/collections/model/sagemaker_base_model.py [0:0]
def _prepare_dpo_input_batch(self, batch, batch_idx):
    """
    Parse a DPO input batch and pre-process it for context parallelism.

    The batch is unpacked through ``self.trainer.datamodule.get_batch`` into
    six tensors (prompt / chosen / rejected ids and their masks). The number
    of sequences, read from ``prompt_ids.shape[0]``, is stored on
    ``self.batch_num_sequences``. When ``context_parallel_degree`` in the
    model config is greater than 1, all six tensors are passed through
    ``get_batch_for_cp_rank``.

    On the first batch only (``batch_idx == 0``) the combined prompt+chosen
    and prompt+rejected widths (dimension 1) are compared against
    ``max_context_width // context_parallel_degree``; a warning is logged on
    mismatch. The check never raises.

    Args:
        batch: Raw batch object understood by the datamodule's ``get_batch``.
        batch_idx: Index of the batch within the current loop.

    Returns:
        Tuple ``(prompt_ids, prompt_mask, chosen_ids, chosen_mask,
        rejected_ids, rejected_mask)``.
    """
    prompt_ids, prompt_mask, chosen_ids, chosen_mask, rejected_ids, rejected_mask = (
        self.trainer.datamodule.get_batch(batch)
    )
    self.batch_num_sequences = prompt_ids.shape[0]

    # Read once; a missing key means context parallelism is disabled.
    cp_degree = self._cfg.get("context_parallel_degree", 1)

    if cp_degree > 1:
        # Apply context parallel processing to all tensors
        prompt_ids, prompt_mask, chosen_ids, chosen_mask, rejected_ids, rejected_mask = get_batch_for_cp_rank(
            (prompt_ids, prompt_mask, chosen_ids, chosen_mask, rejected_ids, rejected_mask)
        )

    if batch_idx == 0:
        # checking only on batch 0 to reduce checks during runtime
        prompt_width = prompt_ids.shape[1]
        chosen_width = prompt_width + chosen_ids.shape[1]
        rejected_width = prompt_width + rejected_ids.shape[1]
        width_over_degree = self._cfg.max_context_width // cp_degree
        if chosen_width != width_over_degree or rejected_width != width_over_degree:
            # The message is built as ONE string (adjacent literals are
            # concatenated). Passing the second fragment as a separate
            # positional argument would make the logger treat it as a
            # %-format argument for a message without placeholders.
            _logger.warning(
                f"Warning: input data passed {prompt_ids.shape}, {chosen_ids.shape}, {rejected_ids.shape} "
                "does not respect max_context_width set. If context parallelism is enabled, "
                "Completion input_ids sequence length == "
                "(model.max_context_width / model.context_parallel_degree) "
            )

    return prompt_ids, prompt_mask, chosen_ids, chosen_mask, rejected_ids, rejected_mask