def _create_examples()

in modeling/dataset.py [0:0]


	def _create_examples(self):
		"""Load the current split's json file and populate ``self.examples``.

		For each raw example this tokenizes the dialogue context plus the
		current utterance, builds the supervision signals (mention detection,
		reference resolution, binary rewrite classification, and the rewriting
		LM labels), and stores everything in a dict. In generation mode only
		the encoder-side inputs and span labels are produced; ``label_ids``
		and ``binary_label_ids`` are left as None.

		Side effects:
			Sets ``self.examples`` (list of dicts) and prints split statistics.
		"""
		# Select the data file for this split.
		if self.data_type == 'train':
			data_file = self.args.train_file
		elif self.data_type == 'dev':
			data_file = self.args.dev_file
		else:
			data_file = self.args.test_file

		# JSON is UTF-8 by spec; don't rely on the platform default encoding.
		with open(data_file, encoding='utf-8') as f:
			data = json.load(f)

		self.examples = []
		for example_num, example in enumerate(tqdm(data, disable=self.args.disable_display)):
			# Optional dataset truncation (data_size == -1 means "use all").
			if self.data_size != -1 and example_num == self.data_size:
				break

			# Unpack the raw example fields.
			context = example['dialogue context']       # list of str, prior turns
			curr_utt = example['current utterance']     # str
			rewt_utt = example['rewrite utterance']     # str, "" when no rewrite
			word_label_index = example['link index']    # word index of mention/reference span
			binary_rewrite = example['rewrite happen']  # bool, whether a rewrite occurs

			# Build the flat input sequence: context turns followed by the
			# current utterance, joined on whitespace. Track the word-level
			# span of the current utterance within that sequence.
			whole_input = copy.deepcopy(context)
			whole_input.append(curr_utt)
			curr_start_idx = sum(len(s.split()) for s in context)  # word start of current utt
			curr_end_idx = curr_start_idx + len(curr_utt.split())

			whole_input = " ".join(whole_input)
			self._check_label_index(whole_input, word_label_index)
			input_ids, wordId2tokenId, tokenId2wordId = self.tokenize_with_map(whole_input)

			# Target side: the rewritten utterance, empty when no rewrite.
			target_utt_ids = [] if rewt_utt == "" else self.tokenizer(rewt_utt)['input_ids']
			target_utt_len = len(target_utt_ids)

			if not self.generation:
				# Training input: CTX <CUR> current utterance <SEP> rewrite <EOS>
				input_ids = input_ids + [self.sep_id] + target_utt_ids + [self.eos_id]
			# else: <sep> is fed at the first decoding step, so input_ids stays as-is.

			# Span supervision is computed the same way in both modes
			# (on the mode-specific input_ids built above).
			mention_label, curr_start_token_idx, curr_end_token_idx = \
				self.prepare_mention_label(input_ids, word_label_index, wordId2tokenId, curr_start_idx, curr_end_idx)
			reference_label_index = self.prepare_reference_label(word_label_index, wordId2tokenId, input_ids)

			if not self.generation:
				# Binary classification of whether a rewrite happens.
				binary_label = self.prepare_binary_label(input_ids, wordId2tokenId, binary_rewrite, curr_end_token_idx)

				# Rewriting signal: mask everything before the target with -100
				# so the LM loss only covers the rewrite + <EOS>.
				ignore_len = len(input_ids) - target_utt_len - 1  # -1 for eos_id
				label_ids = [-100] * ignore_len + target_utt_ids + [self.eos_id]
				assert len(input_ids) == len(label_ids)
			else:
				label_ids = None
				binary_label = None

			self.examples.append({
				'input_ids': input_ids,  # list of ids
				'label_ids': label_ids,  # list of ids, or None in generation mode
				'mention_label_ids': mention_label,
				'curr_start_token_idx': curr_start_token_idx,
				'curr_end_token_idx': curr_end_token_idx,
				'reference_label': reference_label_index,
				'wordId2tokenId': wordId2tokenId,
				'tokenId2wordId': tokenId2wordId,
				'context': context,
				'curr_utt': curr_utt,
				'whole_input': whole_input,
				'rewt_utt': rewt_utt,
				'example_id': example['example index'],
				'spk': example['speaker'],
				'coref_label': word_label_index,
				'binary_label_ids': binary_label,
				'binary_rewrite': binary_rewrite
			})

		print(f'Data Statistics: {self.data_type} -> {len(self.examples)} examples')