def prepare_model_inputs()

in optimum/habana/trl/trainer/ppo_trainer.py


    def prepare_model_inputs(self, queries: torch.Tensor, responses: torch.Tensor):
        """
        Copied from PPOTrainer.prepare_model_inputs: https://github.com/huggingface/trl/blob/v0.9.6/trl/trainer/ppo_trainer.py#L949
        The only differences are:
        - add padding to model inputs for static shape support in forward
        """
        if self.is_encoder_decoder:
            # Encoder-decoder models: queries feed the encoder, responses feed the decoder.
            input_data = self.data_collator(
                [{"input_ids": q, "attention_mask": torch.ones_like(q)} for q in queries]
            ).to(self.current_device)

            decoder_inputs = self.data_collator(
                [{"input_ids": r, "attention_mask": torch.ones_like(r)} for r in responses]
            ).to(self.current_device)

            input_data["decoder_input_ids"] = decoder_inputs["input_ids"]
            input_data["decoder_attention_mask"] = decoder_inputs["attention_mask"]
        else:
            # Decoder-only models: concatenate each query with its response into one sequence.
            input_ids = [torch.cat([q, r]) for q, r in zip(queries, responses)]
            input_data = self.data_collator(
                [{"input_ids": ids, "attention_mask": torch.ones_like(ids)} for ids in input_ids]
            ).to(self.current_device)

        # Right-pad everything to the fixed length config.pad_max_len so the forward
        # pass always sees static shapes (Gaudi/HPU acceleration).
        if self.config.pad_for_acceleration:
            input_data["input_ids"] = torch.nn.functional.pad(
                input_data["input_ids"],
                (0, self.config.pad_max_len - input_data["input_ids"].shape[1]),
                value=self.tokenizer.pad_token_id,
            )
            input_data["attention_mask"] = torch.nn.functional.pad(
                input_data["attention_mask"],
                (
                    0,
                    self.config.pad_max_len - input_data["attention_mask"].shape[1],
                ),
                value=0,
            )
            if self.is_encoder_decoder:
                input_data["decoder_input_ids"] = torch.nn.functional.pad(
                    input_data["decoder_input_ids"],
                    (
                        0,
                        self.config.pad_max_len - input_data["decoder_input_ids"].shape[1],
                    ),
                    value=self.tokenizer.pad_token_id,
                )
                input_data["decoder_attention_mask"] = torch.nn.functional.pad(
                    input_data["decoder_attention_mask"],
                    (
                        0,
                        self.config.pad_max_len - input_data["decoder_attention_mask"].shape[1],
                    ),
                    value=0,
                )

        input_data.pop("labels", None)  # we don't want to compute LM losses
        return input_data
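
As a standalone sketch of the padding behavior added here (decoder-only case), the snippet below concatenates a query with a response and right-pads both input_ids and attention_mask to a fixed length, mirroring the pad_for_acceleration branch above. PAD_MAX_LEN and PAD_TOKEN_ID are made-up placeholders standing in for config.pad_max_len and tokenizer.pad_token_id, not values taken from the trainer.

    import torch
    import torch.nn.functional as F

    PAD_MAX_LEN = 16   # hypothetical stand-in for config.pad_max_len
    PAD_TOKEN_ID = 0   # hypothetical stand-in for tokenizer.pad_token_id

    query = torch.tensor([101, 7592, 2088])            # example query token ids
    response = torch.tensor([2023, 2003, 1037, 3231])  # example response token ids

    # Decoder-only case: query and response form a single sequence.
    input_ids = torch.cat([query, response]).unsqueeze(0)  # shape (1, 7)
    attention_mask = torch.ones_like(input_ids)

    # Right-pad to the fixed length so every batch has the same (static) shape.
    pad_len = PAD_MAX_LEN - input_ids.shape[1]
    input_ids = F.pad(input_ids, (0, pad_len), value=PAD_TOKEN_ID)
    attention_mask = F.pad(attention_mask, (0, pad_len), value=0)

    print(input_ids.shape)       # torch.Size([1, 16])
    print(attention_mask.sum())  # tensor(7): padded positions are masked out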