in scripts/adapet/ADAPET/src/data/COPAReader.py [0:0]
def prepare_pet_batch(self, batch, mode="PET1"):
    '''
    Prepare a PET-style batch of COPA examples for training.

    :param batch: dict with parallel lists under batch["input"]
                  ("premise", "choice1", "choice2", "question") and the
                  gold labels under batch["output"]["lbl"].
    :param mode: which PET pattern to use; must be one of self._pet_names.
    :return: tuple of (input_ids tensor, mask_idx tensor of shape
             (bs, num_lbl, max_num_lbl_tok), list of [choice1, choice2]
             verbalizer strings per example).
    :raises ValueError: if an example's question type is neither "cause"
                        nor "effect".
    '''
    list_premise = batch["input"]["premise"]
    list_choice1 = batch["input"]["choice1"]
    list_choice2 = batch["input"]["choice2"]
    list_question = batch["input"]["question"]
    list_lbl = batch["output"]["lbl"]

    list_input_ids = []
    bs = len(list_choice2)
    # Every mask slot defaults to max_text_length - 1 (ones * max - 1);
    # the real positions are filled in per example below.
    list_mask_idx = np.ones((bs, self.num_lbl, self.config.max_num_lbl_tok)) * self.config.max_text_length - 1
    list_lbl_choices = []

    for b_idx, (p, c1, c2, ques, lbl) in enumerate(
            zip(list_premise, list_choice1, list_choice2, list_question, list_lbl)):
        mask_txt_split_tuple = []

        # Drop the final character of each choice (presumably trailing
        # punctuation — pattern of the original code).
        raw_c1 = c1[:-1]
        raw_c2 = c2[:-1]
        trimmed_c1 = raw_c1
        trimmed_c2 = raw_c2

        # Pad the shorter choice with pad tokens so both verbalizers
        # occupy the same number of [MASK] positions.
        c1_num_lbl_tok = self.get_lbl_num_lbl_tok(trimmed_c1)
        c2_num_lbl_tok = self.get_lbl_num_lbl_tok(trimmed_c2)
        if c1_num_lbl_tok < c2_num_lbl_tok:
            trimmed_c1 = " ".join(trimmed_c1.split(" ") + [self.tokenizer.pad_token] * (c2_num_lbl_tok - c1_num_lbl_tok))
        if c2_num_lbl_tok < c1_num_lbl_tok:
            trimmed_c2 = " ".join(trimmed_c2.split(" ") + [self.tokenizer.pad_token] * (c1_num_lbl_tok - c2_num_lbl_tok))
        max_num_c_lbl_tok = max(c1_num_lbl_tok, c2_num_lbl_tok)

        # Select the pattern set matching the COPA question type.
        if ques == "cause":
            pet_pvps = self.pet_patterns_cause
        elif ques == "effect":
            pet_pvps = self.pet_patterns_effect
        else:
            # Previously an unknown question type fell through to a
            # NameError on pet_pvps; fail loudly with a clear message.
            raise ValueError("Unexpected COPA question type: {}".format(ques))
        pattern = pet_pvps[self._pet_names.index(mode)]

        txt_trim = -1
        for idx, txt_split in enumerate(pattern):
            # Substitute the example's (unpadded) texts into the pattern
            # and expand [MASK] to one mask token per label token.
            mask_txt_split_inp = (txt_split.replace("[PREMISE]", p[:-1])
                                  .replace("[CHOICE1]", raw_c1)
                                  .replace("[CHOICE2]", raw_c2)
                                  .replace("[MASK]", "[MASK] " * max_num_c_lbl_tok))
            mask_txt_split_tuple.append(mask_txt_split_inp)

            # Mark the segment containing the premise as the one to trim
            # if the tokenized text exceeds the length budget.
            if "[PREMISE]" in txt_split:
                txt_trim = idx

        input_ids, mask_idx = tokenize_pet_txt(self.tokenizer, self.config, mask_txt_split_tuple[0],
                                               mask_txt_split_tuple[1], mask_txt_split_tuple[2],
                                               mask_txt_split_tuple[0], mask_txt_split_tuple[1],
                                               mask_txt_split_tuple[2], txt_trim)
        list_input_ids.append(input_ids)
        # Both choices are scored over the same mask span in the input.
        list_mask_idx[b_idx, 0, :max_num_c_lbl_tok] = range(mask_idx, mask_idx + max_num_c_lbl_tok)
        list_mask_idx[b_idx, 1, :max_num_c_lbl_tok] = range(mask_idx, mask_idx + max_num_c_lbl_tok)
        list_lbl_choices.append([trimmed_c1, trimmed_c2])

    return torch.tensor(list_input_ids).to(device), torch.tensor(list_mask_idx).to(device), list_lbl_choices