in trl/trainer/sft_trainer.py [0:0]
def torch_call(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> dict[str, Any]:
    """Collate a list of tokenized examples into a single batch of tensors.

    Every example must provide ``"input_ids"``. The optional keys
    ``"position_ids"``, ``"labels"``, ``"completion_mask"`` and
    ``"assistant_masks"`` are detected by looking at the first example only.

    Args:
        examples: Per-example mappings holding the token ids and any of the
            optional keys listed above.

    Returns:
        A dict with ``"input_ids"``, ``"attention_mask"``, ``"position_ids"``
        (only when ``self.return_position_ids`` is truthy) and ``"labels"``,
        inserted in that order. Label positions whose completion mask (when
        ``self.completion_only_loss`` is truthy) or assistant mask equals 0
        are set to -100.
    """

    def tensors_for(key):
        # One freshly built tensor per example, from the raw value under `key`.
        return [torch.tensor(sample[key]) for sample in examples]

    # Convert to tensor
    token_seqs = tensors_for("input_ids")
    ones_seqs = [torch.ones_like(tokens) for tokens in token_seqs]

    with_positions = self.return_position_ids
    head = examples[0]
    if with_positions:
        if "position_ids" in head:
            position_seqs = tensors_for("position_ids")
        else:
            # No positions supplied: count from 0 within each example.
            position_seqs = [torch.arange(len(tokens)) for tokens in token_seqs]

    # Without explicit labels, the labels start out as a separate copy of the token ids.
    label_seqs = tensors_for("labels" if "labels" in head else "input_ids")

    use_completion_mask = self.completion_only_loss and "completion_mask" in head
    if use_completion_mask:
        completion_seqs = tensors_for("completion_mask")
    has_assistant_masks = "assistant_masks" in head
    if has_assistant_masks:
        assistant_seqs = tensors_for("assistant_masks")

    # Pad
    if self.padding_free:

        def combine(name, seqs):
            # Concatenate every example along dim 0 and add a leading dimension of size 1.
            return torch.cat(seqs, dim=0).unsqueeze(0)

    else:
        # Padding value per field; every field not listed here is padded with 0.
        fill_values = {"input_ids": self.pad_token_id, "labels": -100}

        def combine(name, seqs):
            return pad(
                seqs,
                padding_value=fill_values.get(name, 0),
                padding_side="right",
                pad_to_multiple_of=self.pad_to_multiple_of,
            )

    batch = {}
    batch["input_ids"] = combine("input_ids", token_seqs)
    batch["attention_mask"] = combine("attention_mask", ones_seqs)
    if with_positions:
        batch["position_ids"] = combine("position_ids", position_seqs)
    batch["labels"] = combine("labels", label_seqs)
    if use_completion_mask:
        # mask everything that is not in the completion
        batch["labels"][combine("completion_mask", completion_seqs) == 0] = -100
    if has_assistant_masks:
        # mask everything the assistant mask marks with 0
        batch["labels"][combine("assistant_masks", assistant_seqs) == 0] = -100
    return batch