in Pretraining/dataclass.py [0:0]
def parse_one_dataset(self, dataset_prefix_path, data_set_name, train_test_mode, use_bs, use_da, use_nlg,
                      max_context_len=900):
    """Load one dataset split and build (source, target) token-id pairs for pre-training.

    Reads ``<dataset_prefix_path>/<data_set_name>_<train_test_mode>.json`` and, for
    every turn of every dialogue session, builds one pair per enabled task:

    * NLG: target is the turn's ``resp_id_list``.
    * BS:  target is the turn's ``bspn_id_list``.
    * DA:  target is the turn's ``aspn_id_list``.

    A task contributes pairs only when it is requested (``use_nlg`` / ``use_bs`` /
    ``use_da``) AND the dataset's entry in ``format_mapping_dict`` marks it as
    available (keys ``'nlg'`` / ``'bs'`` / ``'da'``).

    Every source sequence is
    ``<task prefix ids> + [sos_context] + <last max_context_len context ids> + [eos_context]``,
    where the context is all previous user and system ids of the session followed
    by the current turn's ``user_id_list``. Every target is the turn's id list with
    its last element dropped, truncated to ``self.max_tgt_len`` ids, and terminated
    with the task-specific end id (``eos_r`` / ``eos_b`` / ``eos_a``).

    Args:
        dataset_prefix_path: Directory prefix of the dataset JSON files.
        data_set_name: Dataset name; must be a key of ``format_mapping_dict``.
        train_test_mode: Either ``'train'`` or ``'test'``.
        use_bs: Whether to build BS pairs.
        use_da: Whether to build DA pairs.
        use_nlg: Whether to build NLG pairs.
        max_context_len: Number of trailing context ids kept in each source
            sequence (default 900, the previously hard-coded value).

    Returns:
        A list of sessions. Each session is a list of turns, and each turn is a list
        of ``(source_ids, target_ids)`` tuples in the order NLG, BS, DA (enabled
        tasks only). Turns that yield no pairs and sessions that yield no turns
        are omitted.

    Raises:
        ValueError: If ``train_test_mode`` is neither ``'train'`` nor ``'test'``.
    """
    if train_test_mode not in ('train', 'test'):
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        raise ValueError("train_test_mode must be 'train' or 'test', got {!r}".format(train_test_mode))
    # Which task annotations this dataset provides.
    task_flags = format_mapping_dict[data_set_name]
    bs_exist, da_exist, nlg_exist = task_flags['bs'], task_flags['da'], task_flags['nlg']

    dataset_path = dataset_prefix_path + '/' + data_set_name + '_' + train_test_mode + '.json'
    print('Loading data from {}'.format(dataset_path))
    # JSON text is UTF-8 by specification; do not depend on the platform locale encoding.
    with open(dataset_path, encoding='utf-8') as f:
        data = json.load(f)

    def _tail(ids):
        # Last `max_context_len` ids. Unlike `ids[-max_context_len:]`, this returns an
        # empty list (not the whole list) when max_context_len == 0.
        return ids[max(len(ids) - max_context_len, 0):]

    def _make_pair(prefix_ids, context_ids, target_ids, eos_id):
        # Source: task prefix followed by the context wrapped in the context markers.
        src = prefix_ids + [self.sos_context_token_id] + context_ids + [self.eos_context_token_id]
        # Target: drop the last stored id (presumably a trailing end token in the data
        # files -- TODO confirm), cap at max_tgt_len ids, then append the task's end id.
        tgt = target_ids[:-1][:self.max_tgt_len] + [eos_id]
        return src, tgt

    all_sess_list = []
    for one_sess in data:
        one_sess_list = []  # one entry per turn that produced at least one pair
        # Dialogue history (user + system ids) preceding the current turn. Only its last
        # max_context_len ids can ever reach a source sequence, so it is kept trimmed to
        # that length. This bounds memory and copying without changing any output.
        previous_context = []
        for curr_turn in one_sess["dialogue_session"]:
            curr_user_input = curr_turn['user_id_list']
            curr_sys_resp = curr_turn['resp_id_list']
            context_ids = _tail(previous_context + curr_user_input)

            # Pairs for this turn, in the order [nlg, bs, da] (enabled tasks only).
            curr_turn_list = []
            if use_nlg and nlg_exist:
                curr_turn_list.append(
                    _make_pair(self.nlg_prefix_id, context_ids, curr_sys_resp, self.eos_r_token_id))
            if use_bs and bs_exist:
                curr_turn_list.append(
                    _make_pair(self.bs_prefix_id, context_ids, curr_turn['bspn_id_list'], self.eos_b_token_id))
            if use_da and da_exist:
                curr_turn_list.append(
                    _make_pair(self.da_prefix_id, context_ids, curr_turn['aspn_id_list'], self.eos_a_token_id))
            if curr_turn_list:
                one_sess_list.append(curr_turn_list)

            # Extend the history with this turn's user and system ids.
            previous_context = _tail(previous_context + curr_user_input + curr_sys_resp)
        if one_sess_list:
            all_sess_list.append(one_sess_list)
    return all_sess_list