in trl/trainer/utils.py
def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]:
    # first, pad everything to the same length
    padded_batch = {}
    for k in features[0].keys():
        if k.endswith(("_input_ids", "_attention_mask", "_labels", "_pixel_values")):
            if self.is_encoder_decoder:
                to_pad = [torch.LongTensor(ex[k]) for ex in features]

                if (k.startswith("prompt")) and (k.endswith("input_ids")):
                    if self.pad_token_id is None:
                        raise ValueError(
                            "Padding is enabled, but the tokenizer is not configured with a padding token."
                            " Explicitly set `tokenizer.pad_token` (e.g. `tokenizer.pad_token = tokenizer.eos_token`)"
                            " before calling the trainer."
                        )
                    padding_value = self.pad_token_id
                elif k.endswith("_attention_mask"):
                    padding_value = 0
                elif k.startswith(("chosen", "rejected", "completion")) or ("decoder" in k):
                    padding_value = self.label_pad_token_id
                else:
                    raise ValueError(f"Unexpected key in batch '{k}'")
                padded_batch[k] = pad_sequence(to_pad, batch_first=True, padding_value=padding_value)
            else:
                # Set padding value based on the key
                if k.endswith("_input_ids"):
                    if self.pad_token_id is None:
                        raise ValueError(
                            "Padding is enabled, but the tokenizer is not configured with a padding token."
                            " Explicitly set `tokenizer.pad_token` (e.g. `tokenizer.pad_token = tokenizer.eos_token`)"
                            " before calling the trainer."
                        )
                    padding_value = self.pad_token_id
                elif k.endswith("_labels"):
                    padding_value = self.label_pad_token_id
                elif k.endswith("_attention_mask"):
                    padding_value = 0
                elif k.endswith("_pixel_values"):
                    padding_value = 0  # TODO: check if this is correct
                else:
                    raise ValueError(f"Unexpected key in batch '{k}'")

                # Set padding side based on the key
                if k in ["prompt_input_ids", "prompt_attention_mask"]:
                    padding_side = "left"
                else:
                    padding_side = "right"

                # Set the dtype
                if k.endswith("_pixel_values"):
                    dtype = torch.float32  # will be downcast if necessary by the Trainer
                else:
                    dtype = torch.int64

                # Convert to tensor and pad
                to_pad = [torch.tensor(ex[k], dtype=dtype) for ex in features]
                padded_batch[k] = pad(to_pad, padding_value=padding_value, padding_side=padding_side)
        elif k.endswith("_logps"):
            # the cached reference model logprobs
            padded_batch[k] = torch.tensor([ex[k] for ex in features])
        else:
            padded_batch[k] = [ex[k] for ex in features]

    return padded_batch
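For context: `pad_sequence` here is `torch.nn.utils.rnn.pad_sequence`, and `pad` is the left/right-padding helper defined elsewhere in this same module. The sketch below shows how this collator might be exercised on a toy batch. It assumes the method belongs to TRL's `DPODataCollatorWithPadding` dataclass (a guess based on the `pad_token_id`, `label_pad_token_id`, and `is_encoder_decoder` fields; the class name is not shown in this excerpt), and the feature keys and token IDs are made up for illustration.

# Minimal usage sketch (assumption: this is DPODataCollatorWithPadding.__call__).
from trl.trainer.utils import DPODataCollatorWithPadding

collator = DPODataCollatorWithPadding(
    pad_token_id=2,           # assumed tokenizer pad token id
    label_pad_token_id=-100,
    is_encoder_decoder=False,
)

# Hypothetical per-example feature dicts; only the key suffixes matter to the collator.
features = [
    {
        "prompt_input_ids": [5, 6, 7],
        "prompt_attention_mask": [1, 1, 1],
        "chosen_input_ids": [8, 9],
        "chosen_labels": [8, 9],
        "ref_chosen_logps": -1.2,   # any "*_logps" key is stacked into a float tensor
    },
    {
        "prompt_input_ids": [5],
        "prompt_attention_mask": [1],
        "chosen_input_ids": [8, 9, 11],
        "chosen_labels": [8, 9, 11],
        "ref_chosen_logps": -0.7,
    },
]

batch = collator(features)
print(batch["prompt_input_ids"])   # (2, 3); the shorter prompt is left-padded with pad_token_id
print(batch["chosen_labels"])      # (2, 3); the shorter completion is right-padded with -100
print(batch["ref_chosen_logps"])   # 1-D float tensor of the cached reference logprobs

Left-padding the prompt keys keeps the prompt/completion boundary aligned at a fixed position across the batch, which is why `prompt_input_ids` and `prompt_attention_mask` are treated differently from the completion-side keys.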