def prepare_eval_pet_batch()

in scripts/adapet/ADAPET/src/data/COPAReader.py [0:0]


    def prepare_eval_pet_batch(self, batch, mode="PET1"):
        '''
        Prepare a batch for PET evaluation on COPA.

        :param batch: dict with "input" (premise, choice1, choice2, question)
                      and "output" ("lbl") fields; each value is a list.
                      Evaluation only supports batch size 1.
        :param mode: PET pattern name (looked up in self._pet_names) selecting
                     which prompt template to use.
        :return: tuple of
                 - tensor of input ids (one row per candidate choice), on device
                 - list of per-example, per-choice mask token index lists
                 - list of per-example [choice1, choice2] label strings
        :raises ValueError: if a question is neither "cause" nor "effect"
        '''
        list_premise = batch["input"]["premise"]
        list_choice1 = batch["input"]["choice1"]
        list_choice2 = batch["input"]["choice2"]
        list_question = batch["input"]["question"]
        list_lbl = batch["output"]["lbl"]

        list_input_ids = []
        bs = len(batch["input"]["choice2"])
        assert bs == 1, "Evaluation is done only for batch size 1 for COPA"
        list_lbl_choices = []

        list_mask_idx = []
        for b_idx, (p, c1, c2, ques, lbl) in enumerate(zip(list_premise, list_choice1, list_choice2, list_question, list_lbl)):
            # Strip trailing punctuation once; reused for tokenization and templating.
            p_txt, c1_txt, c2_txt = p[:-1], c1[:-1], c2[:-1]
            # Number of label tokens each choice occupies, to size the [MASK] span.
            c1_num_lbl_tok = len(self.tokenizer(c1_txt, add_special_tokens=False)["input_ids"])
            c2_num_lbl_tok = len(self.tokenizer(c2_txt, add_special_tokens=False)["input_ids"])
            num_lbl_toks = [c1_num_lbl_tok, c2_num_lbl_tok]
            list_mask_idx_lbls = []
            if ques == "cause":
                pet_pvps = self.pet_patterns_cause
            elif ques == "effect":
                pet_pvps = self.pet_patterns_effect
            else:
                # Previously fell through to a NameError on pet_pvps; fail clearly instead.
                raise ValueError("COPA question must be 'cause' or 'effect', got %r" % ques)
            pattern = pet_pvps[self._pet_names.index(mode)]

            for lbl_idx, num_lbl_tok in enumerate(num_lbl_toks):
                mask_txt_split_tuple = []
                txt_trim = -1
                for idx, txt_split in enumerate(pattern):
                    # Fill the template; [MASK] expands to one mask token per label token.
                    mask_txt_split_inp = txt_split.replace("[PREMISE]", p_txt) \
                                                  .replace("[CHOICE1]", c1_txt) \
                                                  .replace("[CHOICE2]", c2_txt) \
                                                  .replace("[MASK]", "[MASK] " * num_lbl_tok)
                    mask_txt_split_tuple.append(mask_txt_split_inp)

                    # Mark the segment containing the first choice as the one to trim
                    # if the assembled text exceeds the max length.
                    if "[CHOICE1]" in txt_split:
                        txt_trim = idx

                input_ids, mask_idx = tokenize_pet_txt(self.tokenizer, self.config, mask_txt_split_tuple[0],
                                                       mask_txt_split_tuple[1], mask_txt_split_tuple[2],
                                                       mask_txt_split_tuple[0], mask_txt_split_tuple[1],
                                                       mask_txt_split_tuple[2], txt_trim)
                list_input_ids.append(input_ids)
                # Positions of the mask tokens belonging to this candidate label.
                list_mask_idx_lbls.append(list(range(mask_idx, mask_idx + num_lbl_tok)))

            list_mask_idx.append(list_mask_idx_lbls)
            list_lbl_choices.append([c1_txt, c2_txt])

        return torch.tensor(list_input_ids).to(device), list_mask_idx, list_lbl_choices