def _prepare_dpo_input_batch()

in src/hyperpod_nemo_adapter/collections/model/sagemaker_base_model.py [0:0]


    def _prepare_dpo_input_batch(self, batch, batch_idx):
        """
        Parse an input batch for DPO and, when context parallelism is enabled,
        select this rank's portion of every tensor.

        Args:
            batch: Raw batch from the dataloader; unpacked via
                ``self.trainer.datamodule.get_batch(batch)`` into six tensors.
            batch_idx: Index of the batch. The sequence-width sanity check only
                runs when this is 0, to keep per-step overhead low.

        Returns:
            Tuple ``(prompt_ids, prompt_mask, chosen_ids, chosen_mask,
            rejected_ids, rejected_mask)``. When
            ``context_parallel_degree > 1`` each element is the result of
            ``get_batch_for_cp_rank`` for the current rank.

        Side effects:
            Sets ``self.batch_num_sequences`` to ``prompt_ids.shape[0]``,
            measured before any context-parallel processing (presumably the
            batch dimension -- confirm against ``get_batch``).
        """
        prompt_ids, prompt_mask, chosen_ids, chosen_mask, rejected_ids, rejected_mask = (
            self.trainer.datamodule.get_batch(batch)
        )
        # Recorded before CP processing so it reflects the unsharded batch.
        self.batch_num_sequences = prompt_ids.shape[0]

        cp_degree = self._cfg.get("context_parallel_degree", 1)
        if cp_degree > 1:
            # Apply context parallel processing to all tensors
            prompt_ids, prompt_mask, chosen_ids, chosen_mask, rejected_ids, rejected_mask = get_batch_for_cp_rank(
                (prompt_ids, prompt_mask, chosen_ids, chosen_mask, rejected_ids, rejected_mask)
            )

        if batch_idx == 0:
            # Checking only on batch 0 to reduce checks during runtime.
            # Widths are measured on the (possibly CP-processed) tensors, so the
            # expected value is max_context_width divided by the CP degree.
            chosen_width = prompt_ids.shape[1] + chosen_ids.shape[1]
            rejected_width = prompt_ids.shape[1] + rejected_ids.shape[1]
            width_over_degree = self._cfg.max_context_width // cp_degree
            if chosen_width != width_over_degree or rejected_width != width_over_degree:
                # A single message string: passing extra positional strings to a
                # logger call makes them %-format args, which breaks formatting.
                _logger.warning(
                    f"Warning: input data passed {prompt_ids.shape}, {chosen_ids.shape}, {rejected_ids.shape} "
                    "does not respect max_context_width set. If context parallelism is enabled, "
                    "prompt + completion input_ids sequence length == "
                    "(model.max_context_width / model.context_parallel_degree) "
                    f"(got chosen width {chosen_width}, rejected width {rejected_width}, "
                    f"expected {width_over_degree})."
                )
        return prompt_ids, prompt_mask, chosen_ids, chosen_mask, rejected_ids, rejected_mask