in problem.py [0:0]
def encode_data_list(self, data_list, file_columns, input_types, object_inputs, answer_column_name, min_sentence_len,
extra_feature, max_lengths=None, fixed_lengths=None, file_format="tsv", bpe_encoder=None, predict_mode='batch'):
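"""Encode a list of raw tab-separated lines into id sequences, lengths and targets.

(Docstring added for clarity; descriptions are inferred from the function body.)

Args:
data_list: iterable of raw lines, one example per line.
file_columns: mapping from column name to column index in each line.
input_types: mapping from input type (e.g. 'word', 'char', 'bpe') to its config, including the columns ('cols') it covers.
object_inputs: mapping from model branch (e.g. 'query', 'passage') to its input types.
answer_column_name: target column names; None or empty when there is no target.
min_sentence_len: sequences shorter than this are padded with '<pad>'.
extra_feature: if True (or for sequence tagging), word columns are split on whitespace instead of being tokenized.
max_lengths: optional per-branch maximum lengths; longer sequences are truncated.
fixed_lengths: optional per-branch fixed lengths; sequences are truncated or padded to match.
file_format: input file format (default "tsv"); lines are split on tabs.
bpe_encoder: encoder used when an input type is 'bpe'.
predict_mode: 'batch' tolerates up to 33% illegal rows; other modes abort on the first illegal row.

Returns:
(data, lengths, target, cnt_legal, cnt_illegal)
"""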
data = dict()
lengths = dict()
char_emb = 'char' in [single_input_type.lower() for single_input_type in input_types]
if answer_column_name:
target = {}
lengths['target'] = {}
columns_to_target = {}
for single_target in answer_column_name:
target[single_target] = []
columns_to_target[file_columns[single_target]] = single_target
lengths['target'][single_target] = []
else:
target = None
col_index_types = dict() # input type of each column, namely the inverse of file_columns, e.g. col_index_types[0] = 'query_index'
type2cluster = dict() # e.g. type2cluster['query_index'] = 'word'
type_branches = dict() # branch of input type, e.g. type_branches['query_index'] = 'query'
# for char embedding: these special tokens are never split into characters
word_no_split = ['<start>', '<pad>', '<eos>', '<unk>']
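# layout built below:
#   data[branch][input_type] -> list of encoded id sequences per example
#   lengths[branch]['sentence_length'] -> token count per example
#   lengths[branch]['word_length'] -> per-token char counts (only for char inputs)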
for branch in object_inputs:
data[branch] = dict()
lengths[branch] = dict()
lengths[branch]['sentence_length'] = []
temp_branch_char = False
for input_type in object_inputs[branch]:
type_branches[input_type] = branch
data[branch][input_type] = []
if 'char' in input_type.lower():
temp_branch_char = True
if char_emb and temp_branch_char:
lengths[branch]['word_length'] = []
# extra info for the MRC task: keep the raw passage text and its token offsets
if ProblemTypes[self.problem_type] == ProblemTypes.mrc:
extra_info_type = 'passage'
if extra_info_type not in object_inputs:
raise Exception('MRC task needs a passage in model_inputs, got: {0}'.format(';'.join(list(object_inputs.keys()))))
data[extra_info_type]['extra_passage_text'] = []
data[extra_info_type]['extra_passage_token_offsets'] = []
for input_type in input_types:
for col_name in input_types[input_type]['cols']:
type2cluster[col_name] = input_type
if col_name in file_columns:
col_index_types[file_columns[col_name]] = col_name
cnt_legal = 0
cnt_illegal = 0
cnt_all = 0
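# main loop: parse each tab-separated line, encode the input columns per branch, and collect targets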
for line in data_list:
line_split = line.rstrip().split('\t')
cnt_all += 1
if len(line_split) != len(file_columns):
if predict_mode == 'batch':
cnt_illegal += 1
if cnt_illegal / cnt_all > 0.33:
raise PreprocessError('Too many illegal data rows. Please check the number of data columns or the text tokenization version.')
continue
else:
print('\tThis case is illegal! Please check it and input again!')
return [None] * 5
length_appended_set = set() # branches whose lengths have already been appended to lengths[branch]
if ProblemTypes[self.problem_type] == ProblemTypes.mrc:
passage_token_offsets = None
for i in range(len(line_split)):
line_split[i] = line_split[i].strip()
if i in col_index_types:
# this column is a model input
branch = type_branches[col_index_types[i]]
input_type = [col_index_types[i]]
if type2cluster[col_index_types[i]] == 'word' and char_emb:
# derive the char column name from the word column, e.g. 'query_index' -> 'query_char'
temp_col_char = col_index_types[i].split('_')[0] + '_char'
if temp_col_char in input_types['char']['cols']:
input_type.append(temp_col_char)
if type2cluster[col_index_types[i]] in ('word', 'bpe'):
if self.lowercase:
line_split[i] = line_split[i].lower()
line_split[i] = self.text_preprocessor.preprocess(line_split[i])
if type2cluster[col_index_types[i]] == 'word':
if ProblemTypes[self.problem_type] == ProblemTypes.mrc:
token_offsets = self.tokenizer.span_tokenize(line_split[i])
tokens = [line_split[i][span[0]:span[1]] for span in token_offsets]
if branch == 'passage':
passage_token_offsets = token_offsets
data[extra_info_type]['extra_passage_text'].append(line_split[i])
data[extra_info_type]['extra_passage_token_offsets'].append(passage_token_offsets)
else:
if not extra_feature and ProblemTypes[self.problem_type] != ProblemTypes.sequence_tagging:
tokens = self.tokenizer.tokenize(line_split[i])
else:
tokens = line_split[i].split(' ')
elif type2cluster[col_index_types[i]] == 'bpe':
tokens = bpe_encoder.encode(line_split[i])
else:
tokens = line_split[i].split(' ')
# for sequence tagging, the recorded length must be the true corpus length (no padding/truncation is applied yet)
if ProblemTypes[self.problem_type] == ProblemTypes.sequence_tagging:
if branch not in length_appended_set:
lengths[branch]['sentence_length'].append(len(tokens))
length_appended_set.add(branch)
else:
if len(tokens) != lengths[branch]['sentence_length'][-1]:
# inconsistent input lengths across branches: drop this example
cnt_illegal += 1
if cnt_illegal / cnt_all > 0.33:
raise PreprocessError('Too many illegal data rows. Please check the number of data columns or the text tokenization version.')
lengths[branch]['sentence_length'].pop()
true_len = len(lengths[branch]['sentence_length'])
# remove the partially appended example from every collection
for single_check in (data, lengths, target):
self.delete_example(single_check, true_len)
break
if fixed_lengths and type_branches[input_type[0]] in fixed_lengths:
if len(tokens) >= fixed_lengths[type_branches[input_type[0]]]:
tokens = tokens[:fixed_lengths[type_branches[input_type[0]]]]
else:
tokens = tokens + ['<pad>'] * (fixed_lengths[type_branches[input_type[0]]] - len(tokens))
else:
if max_lengths and type_branches[input_type[0]] in max_lengths: # cut sequences which are too long
tokens = tokens[:max_lengths[type_branches[input_type[0]]]]
if len(tokens) < min_sentence_len:
tokens = tokens + ['<pad>'] * (min_sentence_len - len(tokens))
if self.with_bos_eos:
tokens = ['<start>'] + tokens + ['<eos>'] # requires source_with_start and source_with_end to be True
# for other tasks, the recorded length must match the token length after the fixed/max length operations above
if ProblemTypes[self.problem_type] != ProblemTypes.sequence_tagging:
if branch not in length_appended_set:
lengths[branch]['sentence_length'].append(len(tokens))
length_appended_set.add(branch)
else:
if len(tokens) != lengths[branch]['sentence_length'][-1]:
# inconsistent input lengths across branches: drop this example
cnt_illegal += 1
if cnt_illegal / cnt_all > 0.33:
raise PreprocessError('Too many illegal data rows. Please check the number of data columns or the text tokenization version.')
lengths[branch]['sentence_length'].pop()
true_len = len(lengths[branch]['sentence_length'])
# remove the partially appended example from every collection
for single_check in (data, lengths, target):
self.delete_example(single_check, true_len)
break
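# encode the tokens into ids for each input type of this column (word-level and, if enabled, char-level)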
for single_input_type in input_type:
if 'char' in single_input_type:
temp_word_char = []
temp_word_length = []
for single_token in tokens:
if single_token in word_no_split:
# special tokens are mapped to a single char id instead of being split into characters
temp_id = [self.input_dicts[type2cluster[single_input_type]].id(single_token)]
else:
temp_id = self.input_dicts[type2cluster[single_input_type]].lookup(single_token)
if fixed_lengths and 'word' in fixed_lengths:
if len(temp_id) >= fixed_lengths['word']:
temp_id = temp_id[:fixed_lengths['word']]
else:
temp_id = temp_id + [self.input_dicts[type2cluster[single_input_type]].id('<pad>')] * (fixed_lengths['word'] - len(temp_id))
temp_word_char.append(temp_id)
temp_word_length.append(len(temp_id))
data[branch][single_input_type].append(temp_word_char)
lengths[branch]['word_length'].append(temp_word_length)
else:
data[branch][single_input_type].append(self.input_dicts[type2cluster[single_input_type]].lookup(tokens))
else:
# not a model input: check whether this column is a target
if answer_column_name:
if i in columns_to_target:
# this is target
curr_target = columns_to_target[i]
if ProblemTypes[self.problem_type] == ProblemTypes.mrc:
# MRC targets are either character offsets (int) or the answer text itself (str)
try:
trans2int = int(line_split[i])
except ValueError:
target[curr_target].append(line_split[i])
else:
target[curr_target].append(trans2int)
lengths['target'][curr_target].append(1)
if ProblemTypes[self.problem_type] == ProblemTypes.sequence_tagging:
target_tags = line_split[i].split(" ")
if fixed_lengths and "target" in fixed_lengths:
if len(target_tags) >= fixed_lengths["target"]:
target_tags = target_tags[:fixed_lengths["target"]]
else:
target_tags = target_tags + ['<pad>'] * (fixed_lengths["target"] - len(target_tags))
else:
if max_lengths and "target" in max_lengths: # cut sequences which are too long
target_tags = target_tags[:max_lengths["target"]]
if self.with_bos_eos:
target_tags = ['O'] + target_tags + ['O']
target[curr_target].append(self.output_dict.lookup(target_tags))
lengths['target'][curr_target].append(len(target_tags))
elif ProblemTypes[self.problem_type] == ProblemTypes.classification:
target[curr_target].append(self.output_dict.id(line_split[i]))
lengths['target'][curr_target].append(1)
elif ProblemTypes[self.problem_type] == ProblemTypes.regression:
target[curr_target].append(float(line_split[i]))
lengths['target'][curr_target].append(1)
else:
# this column is not used by the configuration; skip it
pass
cnt_legal += 1
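# MRC: convert the character-level answer offsets into word-level indices via the passage token offsets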
if ProblemTypes[self.problem_type] == ProblemTypes.mrc and target is not None:
if passage_token_offsets:
if 'start_label' not in target or 'end_label' not in target:
raise Exception('MRC task needs both start_label and end_label.')
start_char_label = target['start_label'][-1]
end_char_label = target['end_label'][-1]
start_word_label = 0
end_word_label = len(passage_token_offsets) - 1
for i in range(len(passage_token_offsets)):
token_s, token_e = passage_token_offsets[i]
if token_s <= start_char_label <= token_e:
start_word_label = i
# end_char_label appears to be an exclusive offset, hence the -1
if token_s <= end_char_label - 1 <= token_e:
end_word_label = i
target['start_label'][-1] = start_word_label
target['end_label'][-1] = end_word_label
else:
raise Exception('MRC task needs a passage.')
return data, lengths, target, cnt_legal, cnt_illegal
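# Example usage (hypothetical names; assumes a Problem instance whose dictionaries,
# tokenizer and text_preprocessor have already been built elsewhere in this module):
#   data, lengths, target, cnt_legal, cnt_illegal = problem.encode_data_list(
#       data_list, file_columns, input_types, object_inputs, answer_column_name,
#       min_sentence_len=3, extra_feature=False, max_lengths={'query': 30},
#       file_format="tsv", predict_mode='batch')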