def prepare_pet_batch()

in scripts/adapet/ADAPET/src/data/WSCReader.py [0:0]


    def prepare_pet_batch(self, batch, mode="PET1"):
        '''
        Prepare a WSC batch for PET training: fill the selected pattern with the
        passage, the pronoun, and a run of [MASK] tokens standing in for the candidate noun.

        :param batch: dict with batch["input"]["text"|"pronoun"|"noun"] and batch["output"]["lbl"]
        :param mode: name of the PET pattern to apply (default "PET1")
        :return: (input ids tensor, mask position tensor, list of label choices)
        '''

        list_text = batch["input"]["text"]

        list_pronoun = batch["input"]["pronoun"]
        list_noun = batch["input"]["noun"]
        list_lbl = batch["output"]["lbl"]

        list_input_ids = []
        bs = len(batch["input"]["text"])
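        # Mask positions default to the last index (max_text_length - 1); real positions are filled in per example below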
        list_mask_idx = np.ones((bs, self.num_lbl, self.config.max_num_lbl_tok)) * self.config.max_text_length - 1
        list_lbl_choices = []

        for b_idx, (t, p, n, lbl) in enumerate(zip(list_text, list_pronoun, list_noun, list_lbl)):
            mask_txt_split_tuple = []
            noun_num_lbl_tok = self.get_lbl_num_lbl_tok(n)
            num_lbl_tok = min(noun_num_lbl_tok + random.randint(0, 3), self.config.max_num_lbl_tok)  # jitter the [MASK] count with up to 3 extra tokens, capped at max_num_lbl_tok
            txt_trim = -1
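            # Select the 3-part pattern template that corresponds to the requested PET mode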
            pattern = self.pet_patterns[self._pet_names.index(mode)]

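            # Fill each pattern segment: [TEXT] -> passage, [NNP] -> pronoun, [MASK] -> num_lbl_tok mask tokens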
            for idx, txt_split in enumerate(pattern):
                mask_txt_split_inp = txt_split.replace("[TEXT]", t).replace("[NNP]", p).replace("[MASK]", "[MASK] " * num_lbl_tok)
                mask_txt_split_tuple.append(mask_txt_split_inp)

                # Trim the paragraph
                if "[TEXT]" in txt_split:
                    txt_trim = idx

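            # Tokenize the filled pattern segments, trimming the [TEXT] segment if the input is too long;
            # mask_idx is the position of the first [MASK] token in the encoded input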
            input_ids, mask_idx = tokenize_pet_txt(self.tokenizer, self.config, mask_txt_split_tuple[0],
                                                   mask_txt_split_tuple[1], mask_txt_split_tuple[2],
                                                   mask_txt_split_tuple[0], mask_txt_split_tuple[1],
                                                   mask_txt_split_tuple[2], txt_trim)
            list_input_ids.append(input_ids)
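            # Record the positions of this example's mask tokens in the single (index 0) label slot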
            list_mask_idx[b_idx, 0, :num_lbl_tok] = range(mask_idx, mask_idx + num_lbl_tok)

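            # Pad the noun with pad tokens so the label choice spans exactly num_lbl_tok mask positions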
            lbl_mask = n.split() + [self.tokenizer.pad_token] * (num_lbl_tok - noun_num_lbl_tok)
            list_lbl_choices.append([' '.join(lbl_mask)])


        return torch.tensor(list_input_ids).to(device), torch.tensor(list_mask_idx).to(device), list_lbl_choices
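
For orientation, a minimal usage sketch follows. It assumes `reader` is an already-initialized WSCReader and that the batch has been collated into per-field lists; the field names mirror the dictionary accesses above, and the example sentence and label value are illustrative only.

    # Hypothetical usage sketch: `reader` is assumed to be an initialized WSCReader.
    batch = {
        "input": {
            "text": ["The trophy doesn't fit in the suitcase because it is too big."],
            "pronoun": ["it"],
            "noun": ["the trophy"],
        },
        "output": {"lbl": [1]},
    }

    input_ids, mask_idx, lbl_choices = reader.prepare_pet_batch(batch, mode="PET1")
    # input_ids:   token ids of the filled pattern, one row per example
    # mask_idx:    (bs, num_lbl, max_num_lbl_tok) positions of the [MASK] tokens
    # lbl_choices: one entry per example, the noun padded with pad tokens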