def get_eval_wsc_logits()

in scripts/adapet/ADAPET/src/adapet.py [0:0]


    def get_eval_wsc_logits(self, pet_mask_ids, batch_mask_idx, batch_list_lbl):
        '''
        Greedily fill the masked span of a WSC example and score it against the gold label.
        Code adapted from: https://github.com/timoschick/pet/blob/271910ebd4c30a4e0f8aaba39a153ae3d5822e22/pet/task_helpers.py#L453-L519

        Each iteration runs the model once, and then fills exactly one [MASK] position:
        the one whose top (non-mask) token has the highest softmax probability. This
        repeats until no mask tokens remain. The decoded fill is then compared with the
        label after the T5-paper normalization (lower-case, split on non-letters).

        Only the first example of the batch (row 0) is evaluated, i.e. batch size 1 is assumed.

        NOTE: ``pet_mask_ids`` is modified in place -- its mask tokens are overwritten
        with the predicted token ids.

        :param pet_mask_ids: input-id tensor containing mask tokens; only row 0 is read/filled
        :param batch_mask_idx: [bs,][num_lbl][num_lbl_tok]; positions whose filled tokens are
            decoded are those ``idx`` for which ``idx in batch_mask_idx[0]`` holds
        :param batch_list_lbl: per-example label lists; only ``batch_list_lbl[0][0]`` is used
        :return: ``torch.tensor([[0, 1]])`` if the filled span matches the label,
            otherwise ``torch.tensor([[1, 0]])``
        '''

        def _words(text):
            # Normalization described in the T5 paper: lower-case, keep only runs of ASCII letters.
            return [w for w in re.split(r'[^a-zA-Z]', text.lower().strip()) if w]

        tokenizer = self.tokenizer
        mask_token_id = tokenizer.mask_token_id

        # Only the first example in the batch is evaluated (assumes batch size 1).
        list_lbl = batch_list_lbl[0]
        mask_idx = batch_mask_idx[0]

        while True:
            mask_positions = [
                idx for idx, input_id in enumerate(pet_mask_ids[0]) if input_id == mask_token_id
            ]
            if not mask_positions:  # there are no masks left to process, we are done
                break

            # NOTE(review): the attention mask treats id 0 as padding -- assumes pad_token_id == 0.
            outputs = self.model(pet_mask_ids, (pet_mask_ids > 0).long())
            next_token_probs = outputs[0].softmax(dim=2)[0].detach().cpu().numpy()
            # Never "predict" the mask token itself: writing [MASK] back would leave the
            # input unchanged and the loop would repeat forever. Probabilities are >= 0,
            # so -1 guarantees any real token wins the argmax.
            next_token_probs[:, mask_token_id] = -1.0

            most_confident = ()
            most_confident_score = -1

            for mask_position in mask_positions:
                ntl = next_token_probs[mask_position]
                top_token_id = np.argmax(ntl)
                top_score = ntl[top_token_id]
                if top_score > most_confident_score:
                    most_confident_score = top_score
                    most_confident = (mask_position, top_token_id)

            # Fill only the single most confident position, then re-run the model with it fixed.
            pet_mask_ids[0][most_confident[0]] = most_confident[1]

        input_ids = pet_mask_ids[0].detach().cpu().tolist()
        output_actual = _words(tokenizer.decode([
            input_id for idx, input_id in enumerate(input_ids)
            if idx in mask_idx and input_id not in tokenizer.all_special_ids
        ]))
        output_expected = _words(list_lbl[0])

        # Match if either word list is contained in the other (same rule as PET).
        # NOTE: all() over an empty list is True, so an empty decoded fill counts as a match.
        if all(x in output_expected for x in output_actual) or all(
                x in output_actual for x in output_expected):
            return torch.tensor([[0, 1]])
        return torch.tensor([[1, 0]])