DST/dataclass.py [186:221]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
                        one_turn_dict[key] = turn[key]
                    else:
                        # only tokenize ["user", "usdx", "resp", "bspn", "bsdx", "bspn_reform", "bsdx_reform"]
                        value_text = turn[key]
                        value_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(value_text))
                        value_id = self.replace_sos_eos_token_id(value_id)
                        one_turn_dict[key] = value_id
                one_sess_list.append(one_turn_dict)
            all_session_list.append(one_sess_list)
        p.finish()
        assert len(all_session_list) == len(raw_data_list)
        return all_session_list

    def shuffle_train_data(self):
        random.shuffle(self.train_data_list)

    def tokenized_decode(self, token_id_list):
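        # map ids back to tokens, then regroup subword pieces between special tokens into plain text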
        pred_tokens = self.tokenizer.convert_ids_to_tokens(token_id_list)
        res_text = ''
        curr_list = []
        for token in pred_tokens:
            if token in self.special_token_list + ['<s>', '</s>', '<pad>']:
                if len(curr_list) == 0:
                    res_text += ' ' + token + ' '
                else:
                    curr_res = self.tokenizer.convert_tokens_to_string(curr_list)
                    res_text = res_text + ' ' + curr_res + ' ' + token + ' '
                    curr_list = []
            else:
                curr_list.append(token)
        if len(curr_list) > 0:
            curr_res = self.tokenizer.convert_tokens_to_string(curr_list)
            res_text = res_text + ' ' + curr_res + ' '
        res_text_list = res_text.strip().split()
        res_text = ' '.join(res_text_list).strip()
        return res_text
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
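
Both copies of this block run the same encode path: each free-text field is split into subword tokens, mapped to vocabulary ids, and then adjusted by replace_sos_eos_token_id. Below is a minimal, self-contained sketch of that path, assuming self.tokenizer is a HuggingFace T5 tokenizer; the checkpoint name and the <sos_b>/<eos_b> markers are illustrative stand-ins, not the repo's actual special-token vocabulary, and the replace_sos_eos_token_id step is omitted because its definition is not shown here.
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
from transformers import T5Tokenizer

# Assumed base tokenizer for illustration only.
tokenizer = T5Tokenizer.from_pretrained('t5-small')

# Hypothetical segment markers; registering them keeps each marker a
# single token instead of a run of subword pieces.
tokenizer.add_tokens(['<sos_b>', '<eos_b>'])

# Same two-step encode as the loop above: tokenize, then map to ids.
value_text = '<sos_b> [hotel] area centre <eos_b>'
value_id = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(value_text))
print(value_id)  # one id per marker / subword piece
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -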



E2E_TOD/dataclass.py [198:232]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
                        one_turn_dict[key] = turn[key]
                    else:
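                        # tokenize free-text fields into subword ids; other values pass through unchanged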
                        value_text = turn[key]
                        value_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(value_text))
                        value_id = self.replace_sos_eos_token_id(value_id)
                        one_turn_dict[key] = value_id
                one_sess_list.append(one_turn_dict)
            all_session_list.append(one_sess_list)
        p.finish()
        assert len(all_session_list) == len(raw_data_list)
        return all_session_list

    def shuffle_train_data(self):
        random.shuffle(self.train_data_list)

    def tokenized_decode(self, token_id_list):
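        # map ids back to tokens, then regroup subword pieces between special tokens into plain text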
        pred_tokens = self.tokenizer.convert_ids_to_tokens(token_id_list)
        res_text = ''
        curr_list = []
        for token in pred_tokens:
            if token in self.special_token_list + ['<s>', '</s>', '<pad>']:
                if len(curr_list) == 0:
                    res_text += ' ' + token + ' '
                else:
                    curr_res = self.tokenizer.convert_tokens_to_string(curr_list)
                    res_text = res_text + ' ' + curr_res + ' ' + token + ' '
                    curr_list = []
            else:
                curr_list.append(token)
        if len(curr_list) > 0:
            curr_res = self.tokenizer.convert_tokens_to_string(curr_list)
            res_text = res_text + ' ' + curr_res + ' '
        res_text_list = res_text.strip().split()
        res_text = ' '.join(res_text_list).strip()
        return res_text
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
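
tokenized_decode is the inverse direction: it walks the predicted tokens, buffers consecutive subword pieces, and flushes the buffer through convert_tokens_to_string whenever a special token appears, so each marker stays whitespace-separated from the surrounding text. The same logic, lifted into a free function for illustration (tokenizer and marker names are the same assumptions as in the sketch above; only <s>, </s>, and <pad> come from the original code):
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
from transformers import T5Tokenizer

def decode_with_specials(tokenizer, token_id_list, special_token_list):
    # Same regrouping logic as tokenized_decode above, minus self.
    tokens = tokenizer.convert_ids_to_tokens(token_id_list)
    res_text, curr_list = '', []
    for token in tokens:
        if token in special_token_list + ['<s>', '</s>', '<pad>']:
            if curr_list:  # flush buffered subwords before the marker
                res_text += ' ' + tokenizer.convert_tokens_to_string(curr_list)
                curr_list = []
            res_text += ' ' + token + ' '
        else:
            curr_list.append(token)
    if curr_list:  # flush any trailing subwords
        res_text += ' ' + tokenizer.convert_tokens_to_string(curr_list)
    return ' '.join(res_text.split())  # collapse the padding spaces

tokenizer = T5Tokenizer.from_pretrained('t5-small')
tokenizer.add_tokens(['<sos_b>', '<eos_b>'])  # illustrative markers
ids = tokenizer.convert_tokens_to_ids(
    tokenizer.tokenize('<sos_b> [hotel] area centre <eos_b>'))
print(decode_with_specials(tokenizer, ids, ['<sos_b>', '<eos_b>']))
# -> <sos_b> [hotel] area centre <eos_b>
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -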