def parse_one_dataset()

in Pretraining/dataclass.py [0:0]


    def parse_one_dataset(self, dataset_prefix_path, data_set_name, train_test_mode, use_bs, use_da, use_nlg):
        """Load one dataset split and build per-turn (src, tgt) token-id pairs.

        Reads ``<dataset_prefix_path>/<data_set_name>_<train_test_mode>.json``.
        The file is expected to contain a list of sessions, each a dict whose
        ``"dialogue_session"`` entry is a list of turn dicts. Every turn must
        provide ``'user_id_list'`` and ``'resp_id_list'``; ``'bspn_id_list'`` and
        ``'aspn_id_list'`` are only read when the bs / da task is built.

        For each built task the source sequence is::

            task prefix + [sos_context] + last 900 ids of
            (all previous user/response ids + current user ids) + [eos_context]

        and the target sequence is the task's id list with its last element
        dropped, truncated to ``self.max_tgt_len`` and terminated by the task's
        end token.

        Args:
            dataset_prefix_path: directory that holds the dataset json files.
            data_set_name: dataset name; used in the file name and as the key
                into ``format_mapping_dict``.
            train_test_mode: either ``'train'`` or ``'test'``.
            use_bs: build bs pairs (only if the dataset provides them).
            use_da: build da pairs (only if the dataset provides them).
            use_nlg: build nlg pairs (only if the dataset provides them).

        Returns:
            A list of sessions. Each session is a list of turns and each turn is
            a list of ``(src, tgt)`` tuples in the order nlg, bs, da, restricted
            to the tasks that are both requested and available. Turns without
            any pair and sessions without any turn are omitted.

        Raises:
            AssertionError: if ``train_test_mode`` is not ``'train'``/``'test'``.
        """
        assert train_test_mode in ['train', 'test'], \
            "train_test_mode must be 'train' or 'test'"

        # A task is built only when the caller asks for it AND the dataset
        # declares that annotation in format_mapping_dict.
        task_availability = format_mapping_dict[data_set_name]
        bs_exist = task_availability['bs']
        da_exist = task_availability['da']
        nlg_exist = task_availability['nlg']
        build_nlg = use_nlg and nlg_exist
        build_bs = use_bs and bs_exist
        build_da = use_da and da_exist
        build_any = build_nlg or build_bs or build_da

        dataset_path = dataset_prefix_path + '/' + data_set_name + '_' + train_test_mode + '.json'

        print ('Loading data from {}'.format(dataset_path))
        # Decode explicitly as UTF-8 instead of the platform's locale encoding.
        with open(dataset_path, encoding='utf-8') as f:
            data = json.load(f)

        # Only this many trailing context ids are kept in every source sequence.
        max_context_len = 900

        def build_target(id_list, end_token_id):
            # Drop the final id of the stored sequence (presumably its original
            # end marker -- confirm against the preprocessing), cap the length,
            # then append the task-specific end token.
            return id_list[:-1][:self.max_tgt_len] + [end_token_id]

        all_sess_list = []
        for one_sess in data:
            # one_sess_list: list of turns; each turn is a list of (src, tgt).
            one_sess_list = []
            previous_context = []
            for curr_turn in one_sess["dialogue_session"]:
                curr_user_input = curr_turn['user_id_list']
                curr_sys_resp = curr_turn['resp_id_list']

                # The dialogue history is identical for all three tasks, so it
                # is concatenated once per turn and reused below.
                full_context = previous_context + curr_user_input

                # [(nlg_input, nlg_output), (bs_input, bs_output), (da_input, da_output)]
                curr_turn_list = []
                if build_any:
                    wrapped_context = [self.sos_context_token_id] + \
                        full_context[-max_context_len:] + [self.eos_context_token_id]

                    # Each "prefix + wrapped_context" creates a fresh list, so
                    # the task inputs never alias one another.
                    if build_nlg:
                        nlg_input = self.nlg_prefix_id + wrapped_context
                        nlg_output = build_target(curr_sys_resp, self.eos_r_token_id)
                        curr_turn_list.append((nlg_input, nlg_output))

                    if build_bs:
                        bs_input = self.bs_prefix_id + wrapped_context
                        bs_output = build_target(curr_turn['bspn_id_list'], self.eos_b_token_id)
                        curr_turn_list.append((bs_input, bs_output))

                    if build_da:
                        da_input = self.da_prefix_id + wrapped_context
                        da_output = build_target(curr_turn['aspn_id_list'], self.eos_a_token_id)
                        curr_turn_list.append((da_input, da_output))

                if curr_turn_list:
                    one_sess_list.append(curr_turn_list)

                # The history for the next turn also includes this response.
                previous_context = full_context + curr_sys_resp

            if one_sess_list:
                all_sess_list.append(one_sess_list)
        return all_sess_list