def prepare_eval_pet_batch()

in scripts/adapet/ADAPET/src/data/RecordReader.py [0:0]


    def prepare_eval_pet_batch(self, batch, mode="PET1"):
        '''
        Prepare a single-example batch for PET evaluation.

        For every candidate entity, the PET pattern selected by ``mode`` is
        filled in with the passage and question, and its ``[MASK]``
        placeholder is expanded to one ``[MASK]`` per tokenizer token of the
        candidate. Each candidate yields one tokenized input.

        :param batch: dict whose ``batch["input"]`` holds parallel lists
            ``"passage"``, ``"question"`` and ``"candidate_entity"`` (the
            latter holding an iterable of candidate strings per example).
            Exactly one example is supported.
        :param mode: name of the PET pattern to use; looked up in
            ``self._pet_names`` to index ``self.pet_patterns``.
        :return: tuple ``(input_ids, list_mask_idx, list_candidates)``:
            ``input_ids`` is ``torch.tensor`` of the per-candidate token ids
            (one entry per candidate) moved to ``device``;
            ``list_mask_idx`` holds, per example, one list of mask positions
            per candidate (``mask_idx .. mask_idx + num_cnd_tok - 1``, where
            ``mask_idx`` comes from ``tokenize_pet_txt`` -- presumably the
            first ``[MASK]`` position, TODO confirm);
            ``list_candidates`` is ``batch["input"]["candidate_entity"]``.
        :raises ValueError: if the batch does not contain exactly one example,
            or if ``mode`` is not in ``self._pet_names`` (from ``list.index``).
        '''

        list_passage = batch["input"]["passage"]
        list_question = batch["input"]["question"]
        list_candidates = batch["input"]["candidate_entity"]

        bs = len(list_passage)

        # Explicit check rather than `assert` (which `python -O` strips):
        # input ids below are flattened across examples, so only a batch of
        # one example yields a meaningful one-row-per-candidate tensor.
        if bs != 1:
            raise ValueError(
                f"prepare_eval_pet_batch expects a batch of size 1, got {bs}"
            )

        # The pattern depends only on `mode`, so resolve it once.
        pattern = self.pet_patterns[self._pet_names.index(mode)]

        list_input_ids = []
        list_mask_idx = []

        for p, q, cands in zip(list_passage, list_question, list_candidates):
            list_mask_idx_lbls = []

            for cand in cands:
                # Number of tokens the candidate occupies -> number of [MASK]s.
                num_cnd_tok = len(self.tokenizer(cand, add_special_tokens=False)["input_ids"])
                mask_txt_split_tuple = []
                # Index of the pattern segment containing [PASSAGE]; it is
                # handed to tokenize_pet_txt as the segment to trim. -1 if none.
                txt_trim = -1
                for idx, txt_split in enumerate(pattern):
                    # Substitution order matters and matches the original:
                    # passage, question (+ separator), masks, then @highlight.
                    mask_txt_split_inp = (
                        txt_split.replace("[PASSAGE]", p)
                        .replace("[QUESTION]", q + " [SEP]")
                        .replace("[MASK]", "[MASK] " * num_cnd_tok)
                        .replace("@highlight", "-")
                    )
                    mask_txt_split_tuple.append(mask_txt_split_inp)

                    # Trim the paragraph
                    if "[PASSAGE]" in txt_split:
                        txt_trim = idx

                # At eval time the same filled-in segments are passed as both
                # the text and the masked text.
                input_ids, mask_idx = tokenize_pet_txt(self.tokenizer, self.config, mask_txt_split_tuple[0],
                                                       mask_txt_split_tuple[1], mask_txt_split_tuple[2],
                                                       mask_txt_split_tuple[0], mask_txt_split_tuple[1],
                                                       mask_txt_split_tuple[2], txt_trim)
                list_input_ids.append(input_ids)
                # Consecutive positions covering every [MASK] of this candidate.
                list_mask_idx_lbl = list(range(mask_idx, mask_idx + num_cnd_tok))
                list_mask_idx_lbls.append(list_mask_idx_lbl)

            list_mask_idx.append(list_mask_idx_lbls)

        return torch.tensor(list_input_ids).to(device), list_mask_idx, list_candidates