def __init__()

in DST/dataclass.py [0:0]


    def __init__(self, model_name, tokenizer, data_path_prefix, shuffle_mode='shuffle_session_level', 
        data_mode='train', add_prefix=True, add_special_decoder_token=True, train_data_ratio=1.0):
        '''
            Load, tokenize and index the MultiWOZ train/dev/test splits.

            model_name: t5-small or t5-base or t5-large (must start with 't5');
                    it is also passed to T5Config.from_pretrained to read
                    the decoder start token id

            tokenizer: tokenizer object; this method calls len(), tokenize(),
                    convert_tokens_to_ids(), convert_ids_to_tokens() and reads
                    eos_token_id on it

            data_path_prefix: where the data stores; the three files
                    multiwoz-fine-processed-{train,dev,test}.json are read
                    from this directory

            shuffle_mode: turn level shuffle or session level shuffle
                    (only stored, and only in 'train' mode)

            data_mode: 'train' or 'test'; in 'test' mode the training file is
                    not read and no train_* attributes are created

            add_prefix: whether adding task-specifc prompt to drive the generation

            add_special_decoder_token: whether add special decoder token for generation of each subtasks
                    <sos_b>, <eos_b> for belief state tracking
                    <sos_a>, <eos_a> for dialogue action prediction
                    <sos_r>, <eos_r> for response generation
                    (only stored here; not otherwise used in this method)

            train_data_ratio: must be > 0; when < 1.0 a random subset of
                    int(N * ratio) + 1 training sessions is kept (few-shot)

            Raises Exception('Wrong Data Mode!!!') for any other data_mode,
            AssertionError for invalid flags / model name / ratio.
        '''
        self.add_prefix = add_prefix
        assert self.add_prefix in [True, False]
        self.add_special_decoder_token = add_special_decoder_token
        assert self.add_special_decoder_token in [True, False]

        # Fail fast: reject an unknown mode before touching the tokenizer,
        # fetching the model config or reading any data file.
        if data_mode not in ('train', 'test'):
            raise Exception('Wrong Data Mode!!!')

        self.tokenizer = tokenizer
        print ('Original Tokenizer Size is %d' % len(self.tokenizer))
        # NOTE(review): add_sepcial_tokens is defined elsewhere; presumably it
        # registers the extra tokens looked up below -- confirm.
        self.special_token_list = self.add_sepcial_tokens()
        print ('Tokenizer Size after extension is %d' % len(self.tokenizer))
        self.pad_token_id = self.tokenizer.convert_tokens_to_ids(['<_PAD_>'])[0]
        self.sos_context_token_id = self.tokenizer.convert_tokens_to_ids(['<sos_context>'])[0]
        self.eos_context_token_id = self.tokenizer.convert_tokens_to_ids(['<eos_context>'])[0]

        # initialize bos_token_id, eos_token_id
        self.model_name = model_name
        assert self.model_name.startswith('t5')
        from transformers import T5Config
        t5config = T5Config.from_pretrained(model_name)
        # bos comes from the model config, eos from the tokenizer
        self.bos_token_id = t5config.decoder_start_token_id
        self.eos_token_id = self.tokenizer.eos_token_id

        self.bos_token = self.tokenizer.convert_ids_to_tokens([self.bos_token_id])[0]
        self.eos_token = self.tokenizer.convert_ids_to_tokens([self.eos_token_id])[0]
        print ('bos token is {}, eos token is {}'.format(self.bos_token, self.eos_token))

        # all_sos_token_list / all_eos_token_list are module-level names
        # defined outside this method.
        self.all_sos_token_id_list = []
        for token in all_sos_token_list:
            one_id = self.tokenizer.convert_tokens_to_ids([token])[0]
            self.all_sos_token_id_list.append(one_id)
            print (self.tokenizer.convert_ids_to_tokens([one_id]))
        print (len(self.all_sos_token_id_list))
        self.all_eos_token_id_list = []
        for token in all_eos_token_list:
            one_id = self.tokenizer.convert_tokens_to_ids([token])[0]
            self.all_eos_token_id_list.append(one_id)
            print (self.tokenizer.convert_ids_to_tokens([one_id]))
        print (len(self.all_eos_token_id_list))

        if self.add_prefix:
            bs_prefix_text = 'translate dialogue to belief state:'
            self.bs_prefix_id = self.tokenizer.convert_tokens_to_ids(tokenizer.tokenize(bs_prefix_text))
        else:
            self.bs_prefix_id = []

        import json

        def load_split(split_name):
            # Read one split's JSON file. JSON text is UTF-8, so decode it
            # explicitly instead of relying on the platform default encoding.
            json_path = data_path_prefix + '/multiwoz-fine-processed-' + split_name + '.json'
            with open(json_path, encoding='utf-8') as f:
                return json.load(f)

        if data_mode == 'train':
            train_raw_data = load_split('train')

            self.train_data_ratio = train_data_ratio
            assert self.train_data_ratio > 0
            # few-shot learning
            if self.train_data_ratio < 1.0:
                print ('Few-shot training setup.')
                few_shot_num = int(len(train_raw_data) * self.train_data_ratio) + 1
                # uses the module-level random state, so any seed set by the
                # caller determines which sessions are kept
                random.shuffle(train_raw_data)
                # randomly select a subset of training data
                train_raw_data = train_raw_data[:few_shot_num]
                # report what was actually kept (the slice may be shorter
                # than few_shot_num, e.g. for an empty training set)
                print ('Number of training sessions is {}'.format(len(train_raw_data)))

            print ('Tokenizing raw train data...')
            train_data_id_list = self.tokenize_raw_data(train_raw_data)
            self.train_data_list = self.flatten_data(train_data_id_list)
            # record training data: group the flattened items by dialogue id,
            # remembering the order in which each dialogue id first appears
            self.train_id2session_dict = {}
            self.train_dial_id_list = []
            for item in self.train_data_list:
                one_item_id = item['dial_id']
                if one_item_id not in self.train_id2session_dict:
                    self.train_dial_id_list.append(one_item_id)
                    self.train_id2session_dict[one_item_id] = []
                self.train_id2session_dict[one_item_id].append(item)
            assert len(self.train_dial_id_list) == len(self.train_id2session_dict)
            self.train_num = len(self.train_data_list) 
        else:
            # 'test' mode: no training data is loaded; the empty list only
            # feeds the size report below
            train_raw_data = []

        dev_raw_data = load_split('dev')
        print ('Tokenizing raw dev data...')
        dev_data_id_list = self.tokenize_raw_data(dev_raw_data)
        self.dev_data_list = self.flatten_data(dev_data_id_list)

        test_raw_data = load_split('test')
        print ('Tokenizing raw test data...')
        test_data_id_list = self.tokenize_raw_data(test_raw_data)
        self.test_data_list = self.flatten_data(test_data_id_list)

        print ('The size of raw train, dev and test sets are %d, %d and %d' % \
            (len(train_raw_data), len(dev_raw_data), len(test_raw_data)))

        self.dev_num, self.test_num = len(self.dev_data_list), len(self.test_data_list)
        if data_mode == 'train':
            print ('train turn number is %d, dev turn number is %d, test turn number is %d' % \
                (len(self.train_data_list), len(self.dev_data_list), len(self.test_data_list)))
            self.shuffle_mode = shuffle_mode
            self.ordering_train_data()