DST/dataclass.py [131:184]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
        if data_mode == 'train':
            print ('train turn number is %d, dev turn number is %d, test turn number is %d' % \
                (len(self.train_data_list), len(self.dev_data_list), len(self.test_data_list)))
            self.shuffle_mode = shuffle_mode
            self.ordering_train_data()
        else:
            pass

    def ordering_train_data(self):
        """Reorder ``self.train_data_list`` according to ``self.shuffle_mode``.

        Supported modes:
          * ``'shuffle_turn_level'``: shuffle individual turns in place.
          * ``'shuffle_session_level'``: shuffle the order of dialogue ids in
            ``self.train_dial_id_list`` (in place), then rebuild the turn list
            session by session from ``self.train_id2session_dict``. The turns of
            each session stay contiguous and keep their stored order.
          * ``'unshuffle'``: leave the current order unchanged.

        Raises:
            ValueError: if ``self.shuffle_mode`` is not one of the modes above.
                (ValueError subclasses Exception, so callers that caught the
                previous bare ``Exception`` still work.)
        """
        if self.shuffle_mode == 'shuffle_turn_level':
            random.shuffle(self.train_data_list)
        elif self.shuffle_mode == 'shuffle_session_level':
            # NOTE: this permutes self.train_dial_id_list itself, not a copy.
            random.shuffle(self.train_dial_id_list)
            train_data_list = []
            for dial_id in self.train_dial_id_list:
                train_data_list.extend(self.train_id2session_dict[dial_id])
            # Internal consistency check: the per-session turns must account
            # for exactly as many turns as the flat training list held.
            assert len(train_data_list) == len(self.train_data_list)
            self.train_data_list = train_data_list
        elif self.shuffle_mode == 'unshuffle':
            pass
        else:
            raise ValueError(
                "Wrong Train Ordering Mode!!! Got %r; expected one of "
                "'shuffle_turn_level', 'shuffle_session_level', 'unshuffle'."
                % (self.shuffle_mode,))

    def replace_sos_eos_token_id(self, token_id_list):
        """Collapse start/end marker ids onto the tokenizer's BOS/EOS ids.

        Any id contained in ``self.all_sos_token_id_list`` becomes
        ``self.bos_token_id``; any id contained in ``self.all_eos_token_id_list``
        becomes ``self.eos_token_id``. All other ids are kept as-is. The start
        list is checked first, so an id present in both maps to BOS.

        When ``self.add_special_decoder_token`` is truthy, no id is replaced
        and a new list with the same ids is returned.
        """
        if not self.add_special_decoder_token:
            start_ids = self.all_sos_token_id_list
            end_ids = self.all_eos_token_id_list
        else:
            # Special decoder tokens are kept, so there is nothing to map.
            start_ids, end_ids = [], []

        def _map_one(token_id):
            if token_id in start_ids:
                return self.bos_token_id
            if token_id in end_ids:
                return self.eos_token_id
            return token_id

        return [_map_one(token_id) for token_id in token_id_list]

    def tokenize_raw_data(self, raw_data_list):
        data_num = len(raw_data_list)
        p = progressbar.ProgressBar(data_num)
        p.start()
        all_session_list = []
        for idx in range(data_num):
            p.update(idx)
            one_sess_list = []
            for turn in raw_data_list[idx]:
                one_turn_dict = {}
                for key in turn:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



E2E_TOD/dataclass.py [143:196]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
        if data_mode == 'train':
            print ('train turn number is %d, dev turn number is %d, test turn number is %d' % \
                (len(self.train_data_list), len(self.dev_data_list), len(self.test_data_list)))
            self.shuffle_mode = shuffle_mode
            self.ordering_train_data()
        else:
            pass

    def ordering_train_data(self):
        """Reorder ``self.train_data_list`` according to ``self.shuffle_mode``.

        Supported modes:
          * ``'shuffle_turn_level'``: shuffle individual turns in place.
          * ``'shuffle_session_level'``: shuffle the order of dialogue ids in
            ``self.train_dial_id_list`` (in place), then rebuild the turn list
            session by session from ``self.train_id2session_dict``. The turns of
            each session stay contiguous and keep their stored order.
          * ``'unshuffle'``: leave the current order unchanged.

        Raises:
            ValueError: if ``self.shuffle_mode`` is not one of the modes above.
                (ValueError subclasses Exception, so callers that caught the
                previous bare ``Exception`` still work.)
        """
        if self.shuffle_mode == 'shuffle_turn_level':
            random.shuffle(self.train_data_list)
        elif self.shuffle_mode == 'shuffle_session_level':
            # NOTE: this permutes self.train_dial_id_list itself, not a copy.
            random.shuffle(self.train_dial_id_list)
            train_data_list = []
            for dial_id in self.train_dial_id_list:
                train_data_list.extend(self.train_id2session_dict[dial_id])
            # Internal consistency check: the per-session turns must account
            # for exactly as many turns as the flat training list held.
            assert len(train_data_list) == len(self.train_data_list)
            self.train_data_list = train_data_list
        elif self.shuffle_mode == 'unshuffle':
            pass
        else:
            raise ValueError(
                "Wrong Train Ordering Mode!!! Got %r; expected one of "
                "'shuffle_turn_level', 'shuffle_session_level', 'unshuffle'."
                % (self.shuffle_mode,))

    def replace_sos_eos_token_id(self, token_id_list):
        """Collapse start/end marker ids onto the tokenizer's BOS/EOS ids.

        Any id contained in ``self.all_sos_token_id_list`` becomes
        ``self.bos_token_id``; any id contained in ``self.all_eos_token_id_list``
        becomes ``self.eos_token_id``. All other ids are kept as-is. The start
        list is checked first, so an id present in both maps to BOS.

        When ``self.add_special_decoder_token`` is truthy, no id is replaced
        and a new list with the same ids is returned.
        """
        if not self.add_special_decoder_token:
            start_ids = self.all_sos_token_id_list
            end_ids = self.all_eos_token_id_list
        else:
            # Special decoder tokens are kept, so there is nothing to map.
            start_ids, end_ids = [], []

        def _map_one(token_id):
            if token_id in start_ids:
                return self.bos_token_id
            if token_id in end_ids:
                return self.eos_token_id
            return token_id

        return [_map_one(token_id) for token_id in token_id_list]

    def tokenize_raw_data(self, raw_data_list):
        data_num = len(raw_data_list)
        p = progressbar.ProgressBar(data_num)
        p.start()
        all_session_list = []
        for idx in range(data_num):
            p.update(idx)
            one_sess_list = []
            for turn in raw_data_list[idx]:
                one_turn_dict = {}
                for key in turn:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



