in scripts/adapet/ADAPET/src/adapet.py
# (relies on module-level imports: re, numpy as np, torch)
def get_eval_wsc_logits(self, pet_mask_ids, batch_mask_idx, batch_list_lbl):
    '''
    Get logits from the generated probabilities.
    Code adapted from: https://github.com/timoschick/pet/blob/271910ebd4c30a4e0f8aaba39a153ae3d5822e22/pet/task_helpers.py#L453-L519

    :param pet_mask_ids: input ids with the target span replaced by mask tokens
    :param batch_mask_idx: positions of the masked label tokens, [bs,][num_lbl][num_lbl_tok]
    :param batch_list_lbl: expected label strings for each instance
    :return: a [1, 2] tensor: [[0, 1]] if the decoded span matches the expected label, else [[1, 0]]
    '''
    # Assume batch size 1
    list_lbl = batch_list_lbl[0]
    mask_idx = batch_mask_idx[0]
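    # Overall strategy (following PET's WSC task helper, linked above): treat the
    # masked span as a cloze and fill it greedily, one token per forward pass,
    # always committing the mask position whose top prediction is most confident.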
    while True:
        mask_positions = [
            idx for idx, input_id in enumerate(pet_mask_ids[0]) if input_id == self.tokenizer.mask_token_id
        ]
        if not mask_positions:  # there are no masks left to process, we are done
            input_ids = pet_mask_ids[0].detach().cpu().tolist()
            output_actual = self.tokenizer.decode([
                input_id for idx, input_id in enumerate(input_ids)
                if idx in mask_idx and input_id not in self.tokenizer.all_special_ids
            ])
            output_expected = list_lbl[0]

            # transform both outputs as described in the T5 paper:
            # lowercase, strip, and split on any non-alphabetic character
            output_actual = output_actual.lower().strip()
            output_actual = [w for w in re.split('[^a-zA-Z]', output_actual) if w]
            output_expected = output_expected.lower().strip()
            output_expected = [w for w in re.split('[^a-zA-Z]', output_expected) if w]
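            # For intuition, a hypothetical example of the normalization above:
            # "The city council's" -> ['the', 'city', 'council', 's'], so a
            # prediction of "city council" still counts as a match below, since
            # ['city', 'council'] is contained in the expected word list.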
            # compare outputs: a match if either word list is contained in the other
            if all(x in output_expected for x in output_actual) or all(
                    x in output_actual for x in output_expected):
                return torch.tensor([[0, 1]])
            return torch.tensor([[1, 0]])
        # Run the model over the current (partially filled) sequence;
        # (pet_mask_ids > 0) serves as the attention mask over non-padding tokens
        outputs = self.model(pet_mask_ids, (pet_mask_ids > 0).long())
        next_token_logits = outputs[0]
        next_token_logits = next_token_logits.softmax(dim=2)
        next_token_logits = next_token_logits[0].detach().cpu().numpy()

        # Of the remaining mask positions, find the one whose top prediction
        # has the highest probability, and fill in just that position
        most_confident = ()
        most_confident_score = -1
        for mask_position in mask_positions:
            ntl = next_token_logits[mask_position]
            top_token_id = np.argmax(ntl)
            top_score = ntl[top_token_id]
            if top_score > most_confident_score:
                most_confident_score = top_score
                most_confident = (mask_position, top_token_id)
        pet_mask_ids[0][most_confident[0]] = most_confident[1]
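
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original file): a minimal, standalone
# version of the greedy decoding loop above, assuming a HuggingFace masked LM.
# The helper name `fill_masks_most_confident_first` and the checkpoint in the
# usage example are hypothetical stand-ins, not ADAPET's actual setup.
# ---------------------------------------------------------------------------
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer

def fill_masks_most_confident_first(text, model, tokenizer):
    enc = tokenizer(text, return_tensors='pt')
    input_ids = enc['input_ids']
    while True:
        # Find all remaining mask positions in the (single) sequence
        mask_positions = (input_ids[0] == tokenizer.mask_token_id).nonzero(as_tuple=True)[0]
        if len(mask_positions) == 0:
            return tokenizer.decode(input_ids[0], skip_special_tokens=True)
        with torch.no_grad():
            probs = model(input_ids, attention_mask=enc['attention_mask']).logits.softmax(dim=-1)
        # Best token and its probability at each remaining mask position
        top_probs, top_ids = probs[0, mask_positions].max(dim=-1)
        # Commit only the single most confident position this iteration
        best = top_probs.argmax()
        input_ids[0, mask_positions[best]] = top_ids[best]

# Example usage with a hypothetical checkpoint:
# tok = AutoTokenizer.from_pretrained('albert-xxlarge-v2')
# mlm = AutoModelForMaskedLM.from_pretrained('albert-xxlarge-v2')
# print(fill_masks_most_confident_first(f"Paris is the {tok.mask_token} of France.", mlm, tok))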