def get_eval_multilbl_logits()

in scripts/adapet/ADAPET/src/adapet.py [0:0]


    def get_eval_multilbl_logits(self, pet_mask_ids, batch_mask_idx, batch_list_lbl):
        '''
        Evaluate labels whose verbalizers span multiple tokens

        :param pet_mask_ids: [bs, max_seq_len]
        :param batch_mask_idx: [bs, num_lbl, max_num_lbl_tok]
        :param batch_list_lbl: [bs, num_lbl]
        :return: summed log-probabilities per label, shape [1, num_lbl]
        '''
        log_probs = []
        # Assume batch size 1: only the first example in the batch is scored
        list_lbl = batch_list_lbl[0]
        mask_idx = batch_mask_idx[0]

        # For the generic dataset, repeat the single input row once per candidate label
        if self.config.dataset.lower() == "generic":
            pet_mask_ids = pet_mask_ids.repeat(len(list_lbl), 1)

        for idx, lbl in enumerate(list_lbl):
            lbl_ids = self.tokenizer(lbl, add_special_tokens=False)["input_ids"]
            log_probabilities = []

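            # Iteratively fill this label's mask positions: each pass places the single
            # token the model is currently most confident about, until none remain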
            while True:
                # (mask position, target token id) pairs still to decode; finished tokens are marked -100
                masks = [(pos, tok_id) for pos, tok_id in zip(mask_idx[idx], lbl_ids) if tok_id != -100]
                if not masks:
                    break

                pet_rep = self.model(pet_mask_ids[idx:idx + 1],
                                     (pet_mask_ids[idx:idx + 1] > 0).long())[0]  # [1, max_seq_len, vocab_size]
                # Softmax over the vocabulary; the trailing [0] drops the batch dimension
                next_token_logits = pet_rep.softmax(dim=-1)[0]  # [max_seq_len, vocab_size]

                # Only implementing the 'default' non-autoregressive strategy for now
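                # Pick the (mask position, target token) pair with the highest probability in this pass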
                mask_pos, masked_id = None, None
                max_prob = None
                for m_pos, m_id in masks:
                    m_prob = next_token_logits[m_pos][m_id].item()
                    if max_prob is None or m_prob > max_prob:
                        max_prob = m_prob
                        mask_pos, masked_id = m_pos, m_id

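                # Record the chosen token's log-probability and write the token into the input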
                log_probabilities.append(math.log(max_prob))
                pet_mask_ids[idx][mask_pos] = masked_id
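                # Find which of this label's tokens was just decoded and mark it consumed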
                if isinstance(mask_pos, list):
                    tok_pos = mask_idx[idx].index(mask_pos)
                else:
                    tok_pos = torch.min(torch.nonzero(mask_idx[idx] == mask_pos)[0])
                lbl_ids[tok_pos] = -100

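            # The label's score is the sum of its per-token log-probabilities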
            log_probs.append(sum(log_probabilities))

        return torch.tensor([log_probs])
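
A minimal usage sketch, assuming an instantiated model object (here called `adapet`) whose class defines the method above, with `config.dataset` set to "generic". The token ids, mask positions, and verbalizer strings below are made-up placeholders, not the repository's actual preprocessing output:

    import torch

    # Illustrative inputs for one example (bs = 1) with num_lbl = 2 candidate labels
    # and max_num_lbl_tok = 2 mask slots per label (all ids are made up)
    pet_mask_ids = torch.tensor([[101, 2023, 2003, 103, 103, 102, 0, 0]])  # [1, max_seq_len]
    batch_mask_idx = torch.tensor([[[3, 4], [3, 4]]])                      # [1, num_lbl, max_num_lbl_tok]
    batch_list_lbl = [["terrible", "great"]]                               # [1, num_lbl] verbalizer strings

    # With config.dataset == "generic", the single input row is repeated per label
    # inside the method, so one row of pet_mask_ids is enough here
    with torch.no_grad():
        log_probs = adapet.get_eval_multilbl_logits(pet_mask_ids, batch_mask_idx, batch_list_lbl)

    # log_probs has shape [1, num_lbl]; the highest-scoring label is the prediction
    pred_lbl_idx = log_probs.argmax(dim=-1)

Note that the method writes each decoded token back into the pet_mask_ids rows it works on (the generic path repeats the input into a fresh copy first), so the tensor should not be assumed unchanged after the call.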