in training/run_distillation.py [0:0]
def __call__(self, features: List[Dict[str, Union[List[int], np.ndarray]]]) -> Dict[str, torch.Tensor]:
    """Collate a list of examples into a padded PyTorch batch.

    Args:
        features: One dict per example. Each must contain an ``"input_features"``
            entry and a ``"labels"`` entry (a sequence of token ids).

    Returns:
        The object returned by ``self.processor.feature_extractor.pad`` (padded
        ``input_features`` as PyTorch tensors), with two extra entries:

        * ``"decoder_input_ids"``: the padded label ids without their last position.
        * ``"labels"``: the padded label ids without their first position. Padding
          positions and every position up to and including the first
          ``self.decoder_start_token_id`` are set to -100.
    """
    # Inputs and labels have different lengths and need different padding methods,
    # so split them. The dataloader yields a list of per-example dicts; regroup them
    # into dict-of-lists, which is the form the ``pad`` calls receive here.
    input_features = {"input_features": [feature["input_features"] for feature in features]}
    label_features = {"input_ids": [feature["labels"] for feature in features]}

    batch = self.processor.feature_extractor.pad(
        input_features,
        padding=self.input_padding,
        return_tensors="pt",
    )
    labels_batch = self.processor.tokenizer.pad(
        label_features,
        max_length=self.max_target_length,
        padding=self.target_padding,
        return_tensors="pt",
    )

    # Shift by one: the decoder is fed positions [0, n-1) and trained to predict
    # positions [1, n).
    labels = labels_batch["input_ids"]
    decoder_input_ids = labels[:, :-1]
    labels = labels[:, 1:]
    labels_mask = labels_batch.attention_mask[:, 1:]

    # Replace padding with -100 so it is ignored when computing the loss.
    labels = labels.masked_fill(labels_mask.ne(1), -100)

    # Replace the initial prompt tokens (everything up to and including the first
    # decoder_start_token_id) with -100 so they are ignored by the loss.
    # argmax alone returns 0 both when the token is missing and when it is at
    # position 0, so record presence separately to tell those cases apart.
    is_bos = labels == self.decoder_start_token_id
    has_bos = is_bos.any(dim=1)
    first_bos = torch.argmax(is_bos.long(), dim=1)  # index of the first match
    bos_index = torch.where(has_bos, first_bos + 1, torch.zeros_like(first_bos))
    positions = torch.arange(labels.shape[1], device=labels.device)
    prompt_mask = positions < bos_index[:, None]
    labels = labels.masked_fill(prompt_mask, -100)

    batch["labels"] = labels
    batch["decoder_input_ids"] = decoder_input_ids
    return batch