def __init__()

in E2E_TOD/dataclass.py [0:0]


    def __init__(self, model_name, tokenizer, cfg, data_path_prefix, shuffle_mode='shuffle_session_level', 
        data_mode='train', use_db_as_input=True, add_special_decoder_token=True, train_data_ratio=1.0):
        '''
            Load and tokenize the fine-processed MultiWOZ splits for multi-task training.

            We consider learning several tasks (i.e. belief state generation, dialogue action
            generation, dialogue response generation) with the same model. And we want to break
            the sequential dependencies that exist in several tasks. Each turn therefore counts
            three times (bs, da, nlg) in self.train_num / self.dev_num / self.test_num.

            model_name: t5-small or t5-base or t5-large. A name that does not start with 't5'
                raises Exception('Wrong Model Name!!!').

            tokenizer: the model tokenizer; it is extended in place via
                self.add_sepcial_tokens() before any special-token ids are looked up.

            cfg: configuration object, forwarded to self._build_vocab and MultiWozReader.

            data_path_prefix: where the data stores; the directory that contains the
                'multi-woz-fine-processed' folder with the train/dev/test json files.

            shuffle_mode: turn level shuffle or session level shuffle. Only stored (as
                self.shuffle_mode, right before self.ordering_train_data() is called)
                when data_mode == 'train'.

            data_mode: 'train' loads the train, dev and test splits; 'test' loads only the
                dev and test splits. Any other value raises Exception('Wrong Data Mode!!!').

            use_db_as_input: controls whether we use db as input to generate the
                response/dialogue action. Must be True or False.

            add_special_decoder_token: whether add special decoder token for generation of
                each subtasks. Must be True or False.
                    <sos_b>, <eos_b> for belief state tracking
                    <sos_a>, <eos_a> for dialogue action prediction
                    <sos_r>, <eos_r> for response generation

            train_data_ratio: fraction of training sessions to use (must be > 0). When it is
                below 1.0, a random subset of int(n * ratio) + 1 sessions is kept (few-shot).
        '''
        import json

        def _ids_of(token_list):
            # Convert each token to its id, echoing the round-tripped token for inspection,
            # then the total count (same console output as the original inline loops).
            id_list = []
            for token in token_list:
                one_id = self.tokenizer.convert_tokens_to_ids([token])[0]
                id_list.append(one_id)
                print (self.tokenizer.convert_ids_to_tokens([one_id]))
            print (len(id_list))
            return id_list

        def _prefix_ids(text):
            # Token ids of a task prefix string, tokenized with the (extended) tokenizer.
            return self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(text))

        def _load_raw_split(split_name):
            # Read <data_path_prefix>/multi-woz-fine-processed/multiwoz-fine-processed-<split>.json.
            # An explicit encoding keeps the read independent of the platform locale.
            json_path = data_path_prefix + '/multi-woz-fine-processed/multiwoz-fine-processed-' \
                + split_name + '.json'
            with open(json_path, encoding='utf-8') as f:
                return json.load(f)

        self.use_db_as_input = use_db_as_input
        assert self.use_db_as_input in [True, False]
        self.add_special_decoder_token = add_special_decoder_token
        assert self.add_special_decoder_token in [True, False]

        self.cfg = cfg
        self.vocab = self._build_vocab(self.cfg)
        self.tokenizer = tokenizer
        print ('Original Tokenizer Size is %d' % len(self.tokenizer))
        # special tokens must be registered before their ids are looked up below
        self.special_token_list = self.add_sepcial_tokens()
        print ('Tokenizer Size after extension is %d' % len(self.tokenizer))
        self.pad_token_id = self.tokenizer.convert_tokens_to_ids(['<_PAD_>'])[0]
        self.sos_context_token_id = self.tokenizer.convert_tokens_to_ids(['<sos_context>'])[0]
        self.eos_context_token_id = self.tokenizer.convert_tokens_to_ids(['<eos_context>'])[0]
        # NOTE(review): the reader is always built with data_mode='test', independent of this
        # object's data_mode — presumably it is only used for evaluation; confirm intended.
        self.reader = MultiWozReader(self.tokenizer, self.cfg, data_mode='test')

        # initialize bos_token_id, eos_token_id
        self.model_name = model_name
        if model_name.startswith('t5'):
            from transformers import T5Config
            t5config = T5Config.from_pretrained(model_name)
            # T5 has no dedicated bos token; the decoder start token plays that role
            self.bos_token_id = t5config.decoder_start_token_id
            self.eos_token_id = self.tokenizer.eos_token_id
        else:
            raise Exception('Wrong Model Name!!!')
        self.bos_token = self.tokenizer.convert_ids_to_tokens([self.bos_token_id])[0]
        self.eos_token = self.tokenizer.convert_ids_to_tokens([self.eos_token_id])[0]
        print ('bos token is {}, eos token is {}'.format(self.bos_token, self.eos_token))

        self.all_sos_token_id_list = _ids_of(all_sos_token_list)
        self.all_eos_token_id_list = _ids_of(all_eos_token_list)

        # task prefixes prepended to the input for each of the three subtasks
        self.bs_prefix_id = _prefix_ids('translate dialogue to belief state:')
        self.da_prefix_id = _prefix_ids('translate dialogue to dialogue action:')
        self.nlg_prefix_id = _prefix_ids('translate dialogue to system response:')

        if data_mode == 'train':
            train_raw_data = _load_raw_split('train')

            self.train_data_ratio = train_data_ratio
            assert self.train_data_ratio > 0
            # few-shot learning
            if self.train_data_ratio < 1.0:
                print ('Few-shot training setup.')
                # the +1 guarantees at least one session is requested even for tiny ratios
                few_shot_num = int(len(train_raw_data) * self.train_data_ratio) + 1
                random.shuffle(train_raw_data)
                # randomly select a subset of training data
                train_raw_data = train_raw_data[:few_shot_num]
                # report the number actually kept (can be smaller than few_shot_num)
                print ('Number of training sessions is {}'.format(len(train_raw_data)))
            print ('Tokenizing raw train data...')
            train_data_id_list = self.tokenize_raw_data(train_raw_data)
            self.train_data_list = self.flatten_data(train_data_id_list)
            # record training data: group flattened turns by dialogue id, keeping the
            # dialogue ids in order of first appearance
            self.train_id2session_dict = {}
            self.train_dial_id_list = []
            for item in self.train_data_list:
                one_item_id = item['dial_id']
                if one_item_id not in self.train_id2session_dict:
                    self.train_dial_id_list.append(one_item_id)
                    self.train_id2session_dict[one_item_id] = []
                self.train_id2session_dict[one_item_id].append(item)
            assert len(self.train_dial_id_list) == len(self.train_id2session_dict)
            self.train_num = len(self.train_data_list) * 3 # bs, da, nlg
        elif data_mode == 'test':
            train_raw_data = []
        else:
            raise Exception('Wrong Data Mode!!!')

        dev_raw_data = _load_raw_split('dev')
        print ('Tokenizing raw dev data...')
        dev_data_id_list = self.tokenize_raw_data(dev_raw_data)
        self.dev_data_list = self.flatten_data(dev_data_id_list)

        test_raw_data = _load_raw_split('test')
        print ('Tokenizing raw test data...')
        test_data_id_list = self.tokenize_raw_data(test_raw_data)
        self.test_data_list = self.flatten_data(test_data_id_list)

        print ('The size of raw train, dev and test sets are %d, %d and %d' % \
            (len(train_raw_data), len(dev_raw_data), len(test_raw_data)))

        self.dev_num, self.test_num = len(self.dev_data_list) * 3, len(self.test_data_list) * 3
        if data_mode == 'train':
            print ('train turn number is %d, dev turn number is %d, test turn number is %d' % \
                (len(self.train_data_list), len(self.dev_data_list), len(self.test_data_list)))
            self.shuffle_mode = shuffle_mode
            self.ordering_train_data()