def prepare_pet_mlm_batch()

in scripts/adapet/ADAPET/src/data/RecordReader.py [0:0]


    def prepare_pet_mlm_batch(self, batch, mode="PET1"):
        '''
        Build a batch for the auxiliary masked-LM objective from a passage/question batch.

        For each example, the passage and question are substituted into the PET
        pattern selected by ``mode`` and passed to ``tokenize_pet_mlm_txt``, which
        yields original and masked input ids. Independently, one random label in
        ``[0, self.num_lbl)`` is drawn per example and compared with the gold label.

        :param batch: dict read as ``batch["input"]["passage"]``,
            ``batch["input"]["question"]`` (parallel lists of strings) and
            ``batch["output"]["lbl"]`` (gold labels; presumably a tensor of shape
            (bs,) -- TODO confirm against the collate function).
        :param mode: name of the PET pattern to use; looked up in ``self._pet_names``.
        :return: tuple ``(orig_input_ids, masked_input_ids, prep_lbl, tgt)``:
            the two id tensors stacked over the batch and moved to ``device``;
            ``prep_lbl``, the numpy array of randomly drawn labels; ``tgt``, the
            result of ``prep_lbl == gold label`` (moved to ``device``).
        :raises ValueError: if ``mode`` is not in ``self._pet_names`` (from ``.index``,
            assuming it is a list/tuple).
        '''
        list_passage = batch["input"]["passage"]
        list_question = batch["input"]["question"]
        list_lbl = batch["output"]["lbl"]

        bs = len(list_passage)

        # One random label per example; tgt marks where it coincides with the gold label.
        # NOTE(review): prep_lbl (and the true/false entities in batch["input"]) are
        # never written into the input text below, so tgt does not reflect anything
        # the model sees in this batch -- confirm this is intended.
        prep_lbl = np.random.randint(self.num_lbl, size=bs)
        tgt = torch.from_numpy(prep_lbl).long() == list_lbl

        # The pattern and the index of its passage segment depend only on `mode`,
        # so resolve them once per batch rather than once per example.
        pattern = self.pet_patterns[self._pet_names.index(mode)]
        txt_trim = -1  # stays -1 when no segment contains "[PASSAGE]"
        for idx, txt_split in enumerate(pattern):
            if "[PASSAGE]" in txt_split:
                txt_trim = idx  # the last matching segment wins

        list_orig_input_ids = []
        list_masked_input_ids = []

        for passage, question in zip(list_passage, list_question):
            # Replacement order matters if the inserted text itself contains one of
            # the later placeholders: passage first, then question, then "@highlight".
            txt_split_tuple = [
                txt_split.replace("[PASSAGE]", passage)
                .replace("[QUESTION]", question + " [SEP]")
                .replace("@highlight", "-")
                for txt_split in pattern
            ]

            # Only the first three pattern segments are passed on; the mask indices
            # returned as the third value are not needed here.
            orig_input_ids, masked_input_ids, _ = tokenize_pet_mlm_txt(
                self.tokenizer,
                self.config,
                txt_split_tuple[0],
                txt_split_tuple[1],
                txt_split_tuple[2],
                txt_trim,
            )
            list_orig_input_ids.append(orig_input_ids)
            list_masked_input_ids.append(masked_input_ids)

        return (
            torch.tensor(list_orig_input_ids).to(device),
            torch.tensor(list_masked_input_ids).to(device),
            prep_lbl,
            tgt.to(device),
        )