in modeling/dataset.py [0:0]
def collate_fn(self, batch):
    """Collate a list of example dicts into a single batch dict of tensors.

    Pads the variable-length id sequences in ``batch`` to a common length,
    converts them to long tensors on ``self.args.device``, and passes the
    remaining per-example metadata through as plain lists.

    Args:
        batch: list of example dicts. Each must contain 'input_ids' and
            'mention_label_ids'; when ``self.generation`` is False it must
            also contain 'label_ids' and 'binary_label_ids'. The metadata
            keys ('context', 'curr_utt', 'spk', ...) are always required.

    Returns:
        dict with padded tensors ('input_ids', 'attention_mask',
        'mention_label_ids', and — unless generating — 'label_ids',
        'binary_label_ids'), plus the untouched metadata lists.
        'token_type_ids' is always None; in generation mode 'label_ids'
        and 'binary_label_ids' are None as well.
    """
    device = self.args.device  # hoisted: read once instead of per-tensor

    def _pad_to_tensor(key, pad_value):
        """Pad batch[*][key] with pad_value; return (ids, mask) long tensors on device.

        NOTE(review): assumes self._pad(seqs, pad_value) returns
        (padded_seqs, attention_mask) — confirm against its definition.
        """
        seqs = [example[key] for example in batch]
        padded, mask = self._pad(seqs, pad_value)
        return (torch.tensor(padded).long().to(device),
                torch.tensor(mask).long().to(device))

    input_ids, attention_mask = _pad_to_tensor('input_ids', self.pad_id)
    # mention labels are needed in both modes — computed once, not per branch.
    # Labels are padded with -100 so loss functions ignore the padding positions.
    mention_label_ids, _ = _pad_to_tensor('mention_label_ids', -100)
    if self.generation:
        # No supervision targets are built at generation time.
        label_ids = None
        binary_label_ids = None
    else:
        label_ids, _ = _pad_to_tensor('label_ids', -100)
        binary_label_ids, _ = _pad_to_tensor('binary_label_ids', -100)
    token_type_ids = None  # TODO: not sure if this makes any effect to gpt2

    # Per-example record info, passed through unchanged as Python lists.
    context = [example['context'] for example in batch]
    curr_utt = [example['curr_utt'] for example in batch]
    rewt_utt = [example['rewt_utt'] for example in batch]
    example_ids = [example['example_id'] for example in batch]  # record the example idx in batch
    curr_start_token_idx = [example['curr_start_token_idx'] for example in batch]
    curr_end_token_idx = [example['curr_end_token_idx'] for example in batch]
    reference_label = [example['reference_label'] for example in batch]
    wordId2tokenId = [example['wordId2tokenId'] for example in batch]
    tokenId2wordId = [example['tokenId2wordId'] for example in batch]
    whole_input = [example['whole_input'] for example in batch]
    spk = [example['spk'] for example in batch]
    coref_label = [example['coref_label'] for example in batch]
    binary_rewrite = [example['binary_rewrite'] for example in batch]

    return {'input_ids': input_ids, 'attention_mask': attention_mask,
            'token_type_ids': token_type_ids, 'label_ids': label_ids,
            'context': context, 'curr_utt': curr_utt, 'rewt_utt': rewt_utt,
            'example_ids': example_ids, 'spk': spk, 'mention_label_ids': mention_label_ids,
            'curr_start_token_idx': curr_start_token_idx, 'curr_end_token_idx': curr_end_token_idx,
            'reference_label': reference_label, 'wordId2tokenId': wordId2tokenId,
            'tokenId2wordId': tokenId2wordId, 'whole_input': whole_input,
            'coref_label': coref_label, 'binary_label_ids': binary_label_ids,
            'binary_rewrite': binary_rewrite}