in scripts/adapet/ADAPET/src/data/GenericReader.py [0:0]
def prepare_pet_batch_multi_token_label(self, batch, list_list_txt):
    '''
    Prepare pet batch when the labels consist of multiple tokens
    '''
    bs = len(batch["input"]["TEXT1"])

    list_input_ids = []
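    # Mask positions default to the last valid index (max_text_length - 1);
    # the real positions are filled in per example below.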
    list_mask_idx = np.ones((bs, self.num_lbl, self.get_num_lbl_tok())) * self.config.max_text_length - 1
    txt_trim = 1

    for b_idx in range(bs):
        mask_txt_split_tuple = []

        for idx, txt_split in enumerate(self.pattern):
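            # Fill every [TEXTi] placeholder in the pattern with this example's i-th input text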
            for i in range(1, self.text_ctr):
                txt_split = txt_split.replace("[TEXT%d]" % i, list_list_txt[i - 1][b_idx])
txt_split = txt_split.replace("[LBL]", self.tokenizer.mask_token * self.get_num_lbl_tok())
mask_txt_split_tuple.append(txt_split)
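        # Tokenize the three filled-in pattern segments; mask_idx marks where the mask span starts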
        input_ids, mask_idx = tokenize_pet_txt(self.tokenizer, self.config, mask_txt_split_tuple[0],
                                               mask_txt_split_tuple[1], mask_txt_split_tuple[2],
                                               mask_txt_split_tuple[0], mask_txt_split_tuple[1],
                                               mask_txt_split_tuple[2], txt_trim)
        list_input_ids.append(input_ids)

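        # The longest candidate label (in tokens) determines how many mask positions are used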
        max_num_lbl_tok = 0
        for idx, lbl in enumerate(self.label):
            num_lbl_tok = self.get_lbl_num_lbl_tok(lbl)
            if num_lbl_tok > max_num_lbl_tok:
                max_num_lbl_tok = num_lbl_tok

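        # Every candidate label is scored at the same consecutive mask positions starting at mask_idx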
        for i in range(self.num_lbl):
            list_mask_idx[b_idx, i, :max_num_lbl_tok] = range(mask_idx, mask_idx + max_num_lbl_tok)

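    # Each example in the batch is paired with the full candidate label set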
    list_label = []
    for i in range(bs):
        list_label.append(self.label)

    return torch.tensor(list_input_ids).to(device), torch.tensor(list_mask_idx).to(device).long(), list_label