def collate_fn()

in modeling/dataset.py [0:0]


	def collate_fn(self, batch):
		# pad input sequences to the batch max length and build the 0/1 attention mask
		input_ids = [example['input_ids'] for example in batch]
		input_ids, attention_mask = self._pad(input_ids, self.pad_id)
		input_ids = torch.tensor(input_ids).long().to(self.args.device)
		attention_mask = torch.tensor(attention_mask).long().to(self.args.device)

		if not self.generation:
			# training / teacher-forced evaluation: pad every label sequence with -100
			# so that padded positions are ignored by the loss
			label_ids = [example['label_ids'] for example in batch]
			label_ids, _ = self._pad(label_ids, -100)
			label_ids = torch.tensor(label_ids).long().to(self.args.device)
			mention_label_ids = [example['mention_label_ids'] for example in batch]
			mention_label_ids, _ = self._pad(mention_label_ids, -100)
			mention_label_ids = torch.tensor(mention_label_ids).long().to(self.args.device)
			binary_label_ids = [example['binary_label_ids'] for example in batch]
			binary_label_ids, _ = self._pad(binary_label_ids, -100)
			binary_label_ids = torch.tensor(binary_label_ids).long().to(self.args.device)
		else:
			# generation: language-model and binary labels are not needed; only mention labels are kept
			label_ids = None
			mention_label_ids = [example['mention_label_ids'] for example in batch]
			mention_label_ids, _ = self._pad(mention_label_ids, -100)
			mention_label_ids = torch.tensor(mention_label_ids).long().to(self.args.device)
			binary_label_ids = None
		token_type_ids = None # TODO: unclear whether token_type_ids have any effect for GPT-2

		# record raw example info (kept as Python lists, not tensors)
		context = [example['context'] for example in batch]
		curr_utt = [example['curr_utt'] for example in batch]
		rewt_utt = [example['rewt_utt'] for example in batch]
		example_ids = [example['example_id'] for example in batch] # record each example's id within the batch
		curr_start_token_idx = [example['curr_start_token_idx'] for example in batch]
		curr_end_token_idx = [example['curr_end_token_idx'] for example in batch]
		reference_label = [example['reference_label'] for example in batch]
		wordId2tokenId = [example['wordId2tokenId'] for example in batch]
		tokenId2wordId = [example['tokenId2wordId'] for example in batch]
		whole_input = [example['whole_input'] for example in batch]
		spk = [example['spk'] for example in batch]
		coref_label = [example['coref_label'] for example in batch]
		binary_rewrite = [example['binary_rewrite'] for example in batch]

		return {'input_ids': input_ids, 'attention_mask': attention_mask,
				'token_type_ids': token_type_ids, 'label_ids': label_ids,
				'context': context, 'curr_utt': curr_utt, 'rewt_utt': rewt_utt,
				'example_ids': example_ids, 'spk': spk, 'mention_label_ids': mention_label_ids,
				'curr_start_token_idx': curr_start_token_idx, 'curr_end_token_idx': curr_end_token_idx,
				'reference_label': reference_label, 'wordId2tokenId': wordId2tokenId,
				'tokenId2wordId': tokenId2wordId, 'whole_input': whole_input,
				'coref_label': coref_label, 'binary_label_ids': binary_label_ids,
				'binary_rewrite': binary_rewrite}
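
The `_pad` helper is defined elsewhere in modeling/dataset.py and is not shown here. A minimal sketch of the assumed behavior (right-padding every sequence to the longest one in the batch with the given pad value, and returning a matching 0/1 attention mask):

	def _pad(self, sequences, pad_value):
		# Sketch only, not the project's actual implementation: right-pad each
		# sequence to the batch max length and build a 0/1 mask over real tokens.
		max_len = max(len(seq) for seq in sequences)
		padded, mask = [], []
		for seq in sequences:
			pad_len = max_len - len(seq)
			padded.append(list(seq) + [pad_value] * pad_len)
			mask.append([1] * len(seq) + [0] * pad_len)
		return padded, mask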
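
As a usage sketch, the collate function would typically be passed to a `torch.utils.data.DataLoader`. The dataset class name and constructor arguments below are placeholders, not names confirmed by this file; because tensors are already moved to `args.device` inside `collate_fn`, the loader is assumed to run with `num_workers=0`:

from torch.utils.data import DataLoader

dataset = RewriteDataset(args, split='train', generation=False)  # hypothetical class name / args
loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True,
                    num_workers=0, collate_fn=dataset.collate_fn)

for batch in loader:
	# padded tensors live on args.device; the bookkeeping fields stay as Python lists
	print(batch['input_ids'].shape, batch['attention_mask'].shape)
	break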