in trl/trainer/iterative_sft_trainer.py [0:0]
def prepare_model_inputs(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, labels: torch.Tensor):
    """Collate per-sample inputs into one batch and move it to the model's device.

    Args:
        input_ids: Iterable of per-sample token-id tensors.
        attention_mask: Iterable of per-sample masks aligned with ``input_ids``,
            or ``None`` to use an all-ones mask for every sample.
        labels: Iterable of per-sample label tensors. Only read when
            ``self.is_encoder_decoder`` is true; the other branch passes just
            ``input_ids`` and ``attention_mask`` to the collator.

    Returns:
        The object returned by ``self.data_collator(...).to(device)`` when
        ``self.max_length`` is ``None``; otherwise a plain ``dict`` mapping
        each key to its tensor truncated to at most ``self.max_length``
        positions along the last dimension.

    Raises:
        ValueError: If ``self.max_length`` is set and ``self.truncation_mode``
            is neither ``"keep_start"`` nor ``"keep_end"``.
    """
    if attention_mask is None:
        attention_mask = [torch.ones_like(ids) for ids in input_ids]

    if self.is_encoder_decoder:
        input_data = self.data_collator(
            [
                {"input_ids": ids, "attention_mask": att, "labels": lab}
                for ids, att, lab in zip(input_ids, attention_mask, labels)
            ]
        ).to(self.model.device)

        input_data.pop("decoder_input_ids", None)  # This is directly computed inside the model

        # Replace pad-token label positions with -100.
        # NOTE(review): presumably -100 is the loss's ignore index - confirm.
        input_data["labels"][input_data["labels"] == self.processing_class.pad_token_id] = -100
    else:
        input_data = self.data_collator(
            [{"input_ids": ids, "attention_mask": att} for ids, att in zip(input_ids, attention_mask)]
        ).to(self.model.device)

    # truncate in case the user has provided input_ids, attention_mask and labels
    if self.max_length is not None:
        # Slice the LAST dimension. Indexing dimension 0 (as `v[:n]` does) would
        # cut the collated batch down to `n` samples instead of shortening each
        # sequence. Assumes the collator returns tensors whose last dimension is
        # the sequence axis, e.g. (batch, seq) - verify for custom collators.
        if self.truncation_mode == "keep_start":
            input_data = {k: v[..., : self.max_length] for k, v in input_data.items()}
        elif self.truncation_mode == "keep_end":
            input_data = {k: v[..., -self.max_length :] for k, v in input_data.items()}
        else:
            raise ValueError(f"Unknown truncation mode: {self.truncation_mode}")

    return input_data