def get_multilbl_logits()

in scripts/adapet/ADAPET/src/adapet.py [0:0]


    def get_multilbl_logits(self, pet_mask_ids, mask_idx, batch_list_lbl):
        '''
        Get decoupled per-token label scores at the mask positions for labels
        that may span several tokens.

        Each label string is tokenized (without special tokens), truncated to
        ``max_num_lbl_tok`` tokens and scored token-by-token: token ``k`` of
        label ``j`` in example ``i`` gets the softmax probability the model
        assigns to it at sequence position ``mask_idx[i, j, k]``.

        :param pet_mask_ids: pattern token ids, [bs, max_seq_len]; the attention
            mask passed to the model is ``pet_mask_ids > 0``.
        :param mask_idx: sequence positions to score, [bs, num_lbl, max_num_lbl_tok].
            Its last two dimensions define ``num_lbl`` and ``max_num_lbl_tok``.
        :param batch_list_lbl: one list of label strings per example; the j-th
            string of example i is scored at ``mask_idx[i, j]``. At most
            ``num_lbl`` labels per example.
        :return: tuple ``(scores, lbl_ids, None)`` where ``scores`` is a
            [bs, num_lbl, max_num_lbl_tok] tensor of softmax probabilities (not raw
            logits, despite the method name) with unused slots zeroed, and
            ``lbl_ids`` is a float64 tensor [bs, num_lbl, max_num_lbl_tok] of label
            token ids, 0 where a label has fewer than ``max_num_lbl_tok`` tokens.
        '''
        bs = pet_mask_ids.shape[0]
        # mask_idx is the single source of truth for the label layout, so every
        # tensor built below agrees with it in shape.
        num_lbl, max_num_lbl_tok = mask_idx.shape[1:]

        # Label token ids, truncated to max_num_lbl_tok and right-padded with 0.
        # np.zeros keeps the float64 dtype the returned lbl_ids has always had.
        lbl_ids = np.zeros((bs, num_lbl, max_num_lbl_tok))  # [bs, num_lbl, max_num_lbl_tok]
        for i, list_lbl in enumerate(batch_list_lbl):
            for j, lbl in enumerate(list_lbl):
                tok_ids = self.tokenizer(lbl, add_special_tokens=False)["input_ids"][:max_num_lbl_tok]
                lbl_ids[i, j, :len(tok_ids)] = tok_ids

        # Vocabulary logits at every sequence position.
        pet_logits = self.model(pet_mask_ids, (pet_mask_ids > 0).long())[0]  # [bs, max_seq_len, vs]
        vs = pet_logits.shape[-1]
        # gather() needs its index on the same device as its input, i.e. wherever
        # the model produced its output.
        lbl_ids = torch.from_numpy(lbl_ids).to(pet_logits.device)

        # Vocab distribution at each mask position. expand() returns a view, so no
        # [bs, num_lbl * max_num_lbl_tok, vs] index tensor is materialised.
        flat_mask_idx = mask_idx.reshape(bs, num_lbl * max_num_lbl_tok).long()
        gather_idx = flat_mask_idx.unsqueeze(-1).expand(-1, -1, vs)
        pet_rep_mask_ids_logit = torch.gather(pet_logits, 1, gather_idx)  # [bs, num_lbl * max_num_lbl_tok, vs]
        pet_rep_mask_ids_logit = pet_rep_mask_ids_logit.reshape(bs, num_lbl, max_num_lbl_tok, vs)
        pet_rep_mask_ids_prob = pet_rep_mask_ids_logit.softmax(dim=-1)

        # Probability of each label token at its own mask position.
        pet_rep_mask_ids_lbl_logit = torch.gather(
            pet_rep_mask_ids_prob, 3, lbl_ids.long().unsqueeze(-1)
        ).squeeze(3)  # [bs, num_lbl, max_num_lbl_tok]

        if self.config.dataset.lower() == 'fewglue/wsc':
            # WSC: drop slots whose mask position is the last sequence position
            # (presumably how the WSC reader pads unused mask slots -- TODO confirm).
            # Compared in the 3-D [bs, num_lbl, max_num_lbl_tok] layout so the mask
            # lines up with the scores for any num_lbl, not only num_lbl == 1.
            valid = mask_idx != (pet_mask_ids.shape[-1] - 1)
        else:
            # Drop the zero padding in lbl_ids.
            valid = lbl_ids > 0
        masked_pet_rep_mask_ids_lbl_logit = pet_rep_mask_ids_lbl_logit * valid.long()

        return masked_pet_rep_mask_ids_lbl_logit, lbl_ids, None