in scripts/adapet/ADAPET/src/data/RecordReader.py [0:0]
def prepare_pet_batch(self, batch, mode="PET1"):
    """Build pattern-based (PET) model inputs for one multiple-choice example.

    The passage and question are substituted into the pattern selected by
    ``mode``. The pattern's ``"[MASK] "`` slot is widened to as many mask
    tokens as the longest candidate entity needs, and the text is tokenized.
    All candidates (the true entity first, then the false entities) share the
    same mask start position. Candidate ``c`` fills its first ``n_c`` mask
    slots, where ``n_c`` is its label-token count; unused slots keep a
    padding value.

    :param batch: dict whose ``batch["input"]`` holds the lists
        ``"passage"``, ``"question"``, ``"true_entity"`` and
        ``"false_entities"``, and whose ``batch["output"]["lbl"]`` holds the
        labels. Only a batch size of 1 is supported (asserted).
    :param mode: pattern name; its index in ``self._pet_names`` selects the
        pattern from ``self.pet_patterns``.
    :return: tuple ``(input_ids, mask_idx, lbl_choices)``:
        ``input_ids`` is a tensor of the tokenized ids;
        ``mask_idx`` is a float64 tensor of shape
        ``(bs, config.max_num_lbl, self.get_num_lbl_tok())`` whose unused
        entries equal ``config.max_text_length - 1``;
        ``lbl_choices`` is a list (one entry per example) of
        ``[true_entity, *false_entities]``.
        Both tensors are moved to the module-level ``device``.
    :raises IndexError: if an example has more candidates than
        ``config.max_num_lbl``.
    """
    list_passage = batch["input"]["passage"]
    list_question = batch["input"]["question"]
    list_true_entity = batch["input"]["true_entity"]
    list_false_entities = batch["input"]["false_entities"]
    # Labels are not needed to build the inputs, but are still read and zipped
    # below so a missing key or a short label list behaves as before.
    list_lbl = batch["output"]["lbl"]

    bs = len(list_passage)
    assert bs == 1

    max_num_lbl = self.config.max_num_lbl
    list_input_ids = []
    # Start every mask slot at max_text_length - 1; slots a candidate does not
    # use keep this padding value.
    list_mask_idx = np.ones((bs, max_num_lbl, self.get_num_lbl_tok())) * self.config.max_text_length - 1
    list_lbl_choices = []

    for b_idx, (p, q, te, fe, _lbl) in enumerate(
        zip(list_passage, list_question, list_true_entity, list_false_entities, list_lbl)
    ):
        candidates = [te]
        candidates.extend(fe)
        # Each candidate occupies one row along axis 1 of list_mask_idx.
        if len(candidates) > max_num_lbl:
            raise IndexError(
                "example has %d candidate entities but config.max_num_lbl is %d"
                % (len(candidates), max_num_lbl)
            )

        # Label-token count per candidate, computed once and reused both for
        # sizing the mask span and for filling the mask positions.
        cand_num_lbl_tok = [self.get_lbl_num_lbl_tok(cand) for cand in candidates]
        max_num_lbl_tok = max(cand_num_lbl_tok)

        txt_trim = -1
        pattern = self.pet_patterns[self._pet_names.index(mode)]
        mask_txt_split_tuple = []
        for idx, txt_split in enumerate(pattern):
            # Order matters: passage and question are substituted first, so the
            # "[MASK] " expansion and "@highlight" rewrite also apply to their
            # text (e.g. a "[MASK] " inside the question gets widened).
            mask_txt_split_inp = (
                txt_split.replace("[PASSAGE]", p)
                .replace("[QUESTION]", q + " [SEP]")
                .replace("[MASK] ", "[MASK] " * max_num_lbl_tok)
                .replace("@highlight", "-")
            )
            mask_txt_split_tuple.append(mask_txt_split_inp)
            # Remember which segment holds the passage; it is passed to
            # tokenize_pet_txt as the segment to trim.
            if "[PASSAGE]" in txt_split:
                txt_trim = idx

        # NOTE(review): the same three segments are passed for both triples of
        # tokenize_pet_txt; presumably the second triple is unused for this
        # task — confirm against tokenize_pet_txt.
        input_ids, mask_idx = tokenize_pet_txt(
            self.tokenizer, self.config,
            mask_txt_split_tuple[0], mask_txt_split_tuple[1], mask_txt_split_tuple[2],
            mask_txt_split_tuple[0], mask_txt_split_tuple[1], mask_txt_split_tuple[2],
            txt_trim,
        )

        # All candidates start at the same mask position; candidate c_idx
        # fills only as many slots as it has label tokens.
        for c_idx, n_tok in enumerate(cand_num_lbl_tok):
            list_mask_idx[b_idx, c_idx, :n_tok] = range(mask_idx, mask_idx + n_tok)

        list_input_ids.append(input_ids)
        list_lbl_choices.append(candidates)

    return torch.tensor(list_input_ids).to(device), torch.tensor(list_mask_idx).to(device), list_lbl_choices