def prepare_pet_batch_multi_token_label()

in scripts/adapet/ADAPET/src/data/GenericReader.py [0:0]


    def prepare_pet_batch_multi_token_label(self, batch, list_list_txt):
        '''
        Prepare PET batch when the labels consist of multiple tokens
        '''

        bs = len(batch["input"]["TEXT1"])

        list_input_ids = []
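        # Every mask slot starts out pointing at the last position (max_text_length - 1);
        # the true mask positions are filled in for each example below.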
        list_mask_idx = np.ones((bs, self.num_lbl, self.get_num_lbl_tok())) * self.config.max_text_length - 1
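        # txt_trim: index of the pattern segment that tokenize_pet_txt may truncate when the input is too long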
        txt_trim = 1

        for b_idx in range(bs):
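            # Build the pattern for this example: substitute the [TEXTi] placeholders
            # and expand [LBL] into a run of mask tokens.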
            mask_txt_split_tuple = []

            for idx, txt_split in enumerate(self.pattern):
                for i in range(1, self.text_ctr):
                    txt_split = txt_split.replace("[TEXT%d]" % i, list_list_txt[i - 1][b_idx])
                txt_split = txt_split.replace("[LBL]", self.tokenizer.mask_token * self.get_num_lbl_tok())
                mask_txt_split_tuple.append(txt_split)

            # Tokenize the three filled segments; mask_idx gives the position of the first mask token
            input_ids, mask_idx = tokenize_pet_txt(self.tokenizer, self.config, mask_txt_split_tuple[0],
                                                   mask_txt_split_tuple[1], mask_txt_split_tuple[2],
                                                   mask_txt_split_tuple[0], mask_txt_split_tuple[1],
                                                   mask_txt_split_tuple[2], txt_trim)
            list_input_ids.append(input_ids)

            # Length (in label tokens) of the longest verbalized label
            max_num_lbl_tok = max(self.get_lbl_num_lbl_tok(lbl) for lbl in self.label)

            # Every candidate label shares the same mask positions, starting at mask_idx
            for i in range(self.num_lbl):
                list_mask_idx[b_idx, i, :max_num_lbl_tok] = range(mask_idx, mask_idx + max_num_lbl_tok)

        list_label = [self.label for _ in range(bs)]

        return torch.tensor(list_input_ids).to(device), torch.tensor(list_mask_idx).to(device).long(), list_label
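
Below is a minimal, self-contained sketch (not ADAPET code) of how the mask-index matrix above is laid out. The pattern segments, labels, whitespace "tokenizer", and the fixed mask position are made-up stand-ins for self.pattern, self.label, get_lbl_num_lbl_tok and the index returned by tokenize_pet_txt:

    import numpy as np

    # Hypothetical stand-ins for the reader's configuration
    pattern = ["[TEXT1] ?", "[LBL]", "."]     # three pattern segments, as in self.pattern
    labels = ["yes", "no way"]                # num_lbl = 2; "no way" has 2 label tokens
    max_text_length = 16
    num_lbl_tok = 3                           # padded width of the label-token axis

    def num_label_tokens(lbl):
        # crude whitespace split standing in for self.get_lbl_num_lbl_tok
        return len(lbl.split())

    texts = ["great movie", "terrible plot"]  # TEXT1 column, batch size 2
    bs = len(texts)

    # Fill the pattern for the first example; "[MASK]" stands in for tokenizer.mask_token
    segments = [s.replace("[TEXT1]", texts[0]).replace("[LBL]", "[MASK]" * num_lbl_tok)
                for s in pattern]
    # segments == ['great movie ?', '[MASK][MASK][MASK]', '.']

    # Default every mask slot to the last position, as in the method above
    list_mask_idx = np.ones((bs, len(labels), num_lbl_tok)) * max_text_length - 1

    max_num_lbl_tok = max(num_label_tokens(lbl) for lbl in labels)  # = 2

    for b_idx, txt in enumerate(texts):
        # pretend tokenization reported that the first [MASK] sits at position 5
        mask_idx = 5
        for i in range(len(labels)):
            list_mask_idx[b_idx, i, :max_num_lbl_tok] = range(mask_idx, mask_idx + max_num_lbl_tok)

    print(list_mask_idx.astype(int))
    # every label row ends up as [5, 6, 15]: positions 5-6 hold the label tokens,
    # and the unused third slot keeps the default index max_text_length - 1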