in scripts/adapet/ADAPET/src/data/COPAReader.py [0:0]
def prepare_pet_mlm_batch(self, batch, mode="PET1"):
    """Build label-conditioned MLM training inputs for a COPA batch.

    For every example, the segments of pattern ``"PET3"`` (taken from the
    cause- or effect-pattern list, depending on the question) are filled with
    the premise, both choices, and -- in the ``[MASK]`` slot -- the text of
    the gold choice. The three filled segments are then tokenized by
    ``tokenize_pet_mlm_txt``.

    :param batch: dict whose ``batch["input"]`` holds parallel sequences
        ``"premise"``, ``"choice1"``, ``"choice2"`` and ``"question"``, and
        whose ``batch["output"]["lbl"]`` holds the gold labels. The labels are
        compared against a torch tensor and each element's ``.item()`` is
        read, so presumably they form a 1-D integer tensor (confirm).
    :param mode: ignored; COPA always uses ``"PET3"``. It is accepted only so
        the signature stays unchanged.
    :return: tuple ``(orig_input_ids, masked_input_ids, prep_lbl, tgt)``.

        - The first two are tensors built from the per-example outputs of
          ``tokenize_pet_mlm_txt`` and moved to ``device``.
        - ``prep_lbl`` is a numpy array of labels sampled uniformly from
          ``[0, self.num_lbl)``.
        - ``tgt`` is a boolean tensor (on ``device``) that is True where
          ``prep_lbl`` equals the gold label.
    :raises ValueError: if a question is neither ``"cause"`` nor
        ``"effect"``, or a gold label is neither 0 nor 1.
    """
    # Always use pattern 3 for COPA, whatever mode was requested.
    mode = "PET3"

    list_premise = batch["input"]["premise"]
    list_choice1 = batch["input"]["choice1"]
    list_choice2 = batch["input"]["choice2"]
    list_question = batch["input"]["question"]
    list_lbl = batch["output"]["lbl"]

    bs = len(list_question)

    # Sample one candidate label per example; tgt marks where it hits the gold label.
    prep_lbl = np.random.randint(self.num_lbl, size=bs)
    tgt = torch.from_numpy(prep_lbl).long() == list_lbl

    # NOTE(review): prep_lbl never influences the text built below. The [MASK]
    # slot is always filled with the *gold* choice (from list_lbl), yet
    # tgt = (prep_lbl == gold) is returned next to it. If the caller's loss
    # reads tgt as "the filled-in choice is correct", the two disagree whenever
    # prep_lbl != gold. Confirm whether the loop should iterate over prep_lbl
    # instead of list_lbl.

    # The pattern position is the same for every example, so look it up once.
    pattern_idx = self._pet_names.index(mode)

    list_orig_input_ids = []
    list_masked_input_ids = []

    for p, c1, c2, ques, lbl in zip(list_premise, list_choice1, list_choice2, list_question, list_lbl):
        if ques == "cause":
            pet_pvps = self.pet_patterns_cause
        elif ques == "effect":
            pet_pvps = self.pet_patterns_effect
        else:
            # Fail loudly instead of leaving pet_pvps unbound or reusing the
            # previous example's patterns.
            raise ValueError("Invalid question: %r" % (ques,))
        pattern = pet_pvps[pattern_idx]

        lbl_val = lbl.item()
        if lbl_val == 0:
            chosen = c1
        elif lbl_val == 1:
            chosen = c2
        else:
            raise ValueError("Invalid Lbl")

        # [:-1] drops each sentence's final character (presumably its trailing
        # period -- confirm against the data).
        premise_txt = p[:-1]
        choice1_txt = c1[:-1]
        choice2_txt = c2[:-1]
        lbl_choice = chosen[:-1]

        txt_split_tuple = []
        # Index of the segment that contained [PREMISE], or -1 if none. It is
        # passed to tokenize_pet_mlm_txt, presumably as the segment it may
        # truncate ("trim the paragraph") -- confirm in that helper.
        txt_trim = -1
        for idx, txt_split in enumerate(pattern):
            txt_split_inp = (
                txt_split.replace("[PREMISE]", premise_txt)
                .replace("[CHOICE1]", choice1_txt)
                .replace("[CHOICE2]", choice2_txt)
                .replace("[MASK]", lbl_choice)
            )
            txt_split_tuple.append(txt_split_inp)
            if "[PREMISE]" in txt_split:
                txt_trim = idx

        # The first three filled segments are passed positionally; the pattern
        # is assumed to have (at least) three segments.
        orig_input_ids, masked_input_ids, _ = tokenize_pet_mlm_txt(
            self.tokenizer, self.config, txt_split_tuple[0], txt_split_tuple[1], txt_split_tuple[2], txt_trim
        )
        list_orig_input_ids.append(orig_input_ids)
        list_masked_input_ids.append(masked_input_ids)

    # `device` is a name defined outside this view (module scope).
    return (
        torch.tensor(list_orig_input_ids).to(device),
        torch.tensor(list_masked_input_ids).to(device),
        prep_lbl,
        tgt.to(device),
    )