in modeling/dataset.py
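# imports assumed from the top of modeling/dataset.py:
import copy
import json
from tqdm import tqdm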
def _create_examples(self):
    if self.data_type == 'train':
        data_file = self.args.train_file
    elif self.data_type == 'dev':
        data_file = self.args.dev_file
    else:
        data_file = self.args.test_file
    with open(data_file) as f:
        data = json.load(f)
    self.examples = []
    for example_num, example in enumerate(tqdm(data, disable=self.args.disable_display)):
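        # a data_size of -1 presumably means "use the full split"; any other
        # value truncates the dataset, e.g. for quick debugging runs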
        if self.data_size != -1 and example_num == self.data_size:
            break
        # get data
        context = example['dialogue context']     # dialogue context, list of str
        curr_utt = example['current utterance']   # current utterance, str
        rewt_utt = example['rewrite utterance']   # rewritten utterance, str
        word_label_index = example['link index']  # word index of the mention/reference span
        binary_rewrite = example['rewrite happen']  # whether a rewrite happened, bool
        # prepare the input sequence to the model
        whole_input = copy.deepcopy(context)
        whole_input.append(curr_utt)
        curr_start_idx = sum(len(s.split()) for s in context)  # word-level start index of the current utterance
        curr_end_idx = curr_start_idx + len(curr_utt.split())  # word-level end index (exclusive)
        whole_input = " ".join(whole_input)
        self._check_label_index(whole_input, word_label_index)
        input_ids, wordId2tokenId, tokenId2wordId = self.tokenize_with_map(whole_input)
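        # tokenize_with_map presumably returns, alongside the token ids, maps between
        # word positions and subword token positions, so the word-level span labels
        # above can be projected onto tokens (a sketch of such a helper follows
        # this function)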
        if rewt_utt == "":
            rewt_utt_ids = []
        else:
            rewt_utt_ids = self.tokenizer(rewt_utt)['input_ids']  # list of token ids
        target_utt_ids = rewt_utt_ids
        target_utt_len = len(target_utt_ids)
        if not self.generation:
            # input seq: CTX <CUR> current utterance <SEP> rewritten utterance <EOS>
            input_ids = input_ids + [self.sep_id] + target_utt_ids + [self.eos_id]
            # mention detection signal
            mention_label, curr_start_token_idx, curr_end_token_idx = \
                self.prepare_mention_label(input_ids, word_label_index, wordId2tokenId, curr_start_idx, curr_end_idx)
            # reference resolution signal
            reference_label_index = self.prepare_reference_label(word_label_index, wordId2tokenId, input_ids)
            # binary classification signal: whether a rewrite is needed
            binary_label = self.prepare_binary_label(input_ids, wordId2tokenId, binary_rewrite, curr_end_token_idx)
            # rewriting signal
            ignore_len = len(input_ids) - target_utt_len - 1  # -1 accounts for the trailing eos_id
            label_ids = [-100] * ignore_len + target_utt_ids + [self.eos_id]
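            # -100 is the default ignore_index of PyTorch's CrossEntropyLoss, so the
            # rewriting loss is computed only over the rewritten utterance and <EOS>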
            assert len(input_ids) == len(label_ids)
        else:  # generation
            # input ends with the current utterance; <SEP> is fed as the first
            # step during decoding, so no LM labels are needed
            label_ids = None
            mention_label, curr_start_token_idx, curr_end_token_idx = \
                self.prepare_mention_label(input_ids, word_label_index, wordId2tokenId, curr_start_idx, curr_end_idx)
            reference_label_index = self.prepare_reference_label(word_label_index, wordId2tokenId, input_ids)
            binary_label = None
        self.examples.append({
            'input_ids': input_ids,  # list of ids
            'label_ids': label_ids,  # list of ids (None in generation mode)
            'mention_label_ids': mention_label,
            'curr_start_token_idx': curr_start_token_idx,
            'curr_end_token_idx': curr_end_token_idx,
            'reference_label': reference_label_index,
            'wordId2tokenId': wordId2tokenId,
            'tokenId2wordId': tokenId2wordId,
            'context': context,
            'curr_utt': curr_utt,
            'whole_input': whole_input,
            'rewt_utt': rewt_utt,
            'example_id': example['example index'],
            'spk': example['speaker'],
            'coref_label': word_label_index,
            'binary_label_ids': binary_label,
            'binary_rewrite': binary_rewrite
        })
    print('Data Statistics: {} -> {} examples'.format(self.data_type, len(self.examples)))
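The wordId2tokenId / tokenId2wordId maps are what let the word-level 'link index'
labels survive subword tokenization. Below is a minimal sketch of how such maps can
be built with a HuggingFace-style tokenizer, assuming whitespace-delimited words;
the helper name and per-word tokenization strategy are illustrative assumptions,
not the repository's actual tokenize_with_map.

def tokenize_with_map_sketch(tokenizer, text):
    # Tokenize word by word so we know exactly which subword positions
    # each whitespace-delimited word occupies.
    input_ids, wordId2tokenId, tokenId2wordId = [], {}, {}
    for word_idx, word in enumerate(text.split()):
        # For BPE tokenizers (e.g. GPT-2), a leading space makes mid-sentence
        # words tokenize roughly as they would inside the full string.
        prefix = " " if word_idx > 0 else ""
        piece_ids = tokenizer(prefix + word, add_special_tokens=False)['input_ids']
        wordId2tokenId[word_idx] = []
        for pid in piece_ids:
            tokenId2wordId[len(input_ids)] = word_idx
            wordId2tokenId[word_idx].append(len(input_ids))
            input_ids.append(pid)
    return input_ids, wordId2tokenId, tokenId2wordId

With such maps, a word-level span (start, end) converts to token level via
wordId2tokenId[start][0] and wordId2tokenId[end - 1][-1] + 1, which is presumably
what prepare_mention_label does with curr_start_idx and curr_end_idx.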