in scripts/adapet/ADAPET/src/adapet.py [0:0]
def get_multilbl_logits(self, pet_mask_ids, mask_idx, batch_list_lbl ):
    '''
    Get decoupled label logits at mask positions for multiple mask tokens
    :param pet_mask_ids: input ids of the pattern with mask tokens [bs, max_seq_len]
    :param mask_idx: mask-token positions for each label [bs, num_lbl, max_num_lbl_tok]
    :param batch_list_lbl: list of label verbalizations for each example in the batch
    :return: masked label-token scores [bs, num_lbl, max_num_lbl_tok], label token ids, None
    '''
    bs = pet_mask_ids.shape[0]
    num_lbl, max_num_lbl_tok = mask_idx.shape[1:]

    # Token ids for every label verbalization, zero-padded up to max_num_lbl_tok
    lbl_ids = np.zeros((bs, self.num_lbl, self.config.max_num_lbl_tok))  # [bs, num_lbl, max_num_lbl_tok]

    # Get lbl ids for multi token labels, truncating labels longer than max_num_lbl_tok
    for i, list_lbl in enumerate(batch_list_lbl):
        for j, lbl in enumerate(list_lbl):
            i_j_lbl_ids = self.tokenizer(lbl, add_special_tokens=False)["input_ids"]
            i_j_lbl_ids = i_j_lbl_ids[:min(self.config.max_num_lbl_tok, len(i_j_lbl_ids))]
            lbl_ids[i, j, :len(i_j_lbl_ids)] = i_j_lbl_ids

    lbl_ids = torch.from_numpy(lbl_ids).to(device)
    # Logits over the vocabulary at every sequence position
    pet_logits = self.model(pet_mask_ids, (pet_mask_ids > 0).long())[0]  # [bs, max_seq_len, vocab_size]
    vs = pet_logits.shape[-1]

    # Gather the logits at the mask positions and normalize them into probabilities
    mask_idx = mask_idx.reshape(bs, num_lbl * self.config.max_num_lbl_tok)
    pet_rep_mask_ids_logit = torch.gather(pet_logits, 1, mask_idx[:, :, None].repeat(1, 1, vs).long())  # [bs, num_lbl * max_num_lbl_tok, vs]
    pet_rep_mask_ids_logit = pet_rep_mask_ids_logit.reshape(bs, num_lbl, self.config.max_num_lbl_tok, vs)  # [bs, num_lbl, max_num_lbl_tok, vs]
    pet_rep_mask_ids_prob = pet_rep_mask_ids_logit.softmax(dim=-1)
    # Probability of each gold label token at its mask position
    lbl_ids_expd = lbl_ids[..., None]  # [bs, num_lbl, max_num_lbl_tok, 1]
    pet_rep_mask_ids_lbl_logit = torch.gather(pet_rep_mask_ids_prob, 3, lbl_ids_expd.long()).squeeze(3)  # [bs, num_lbl, max_num_lbl_tok]

    if self.config.dataset.lower() == 'fewglue/wsc':
        # WSC: zero out slots whose mask index points at the last (padding) position
        masked_pet_rep_mask_ids_lbl_logit = pet_rep_mask_ids_lbl_logit * (mask_idx != (pet_mask_ids.shape[-1] - 1)).unsqueeze(1).long()
    else:
        # Otherwise zero out padded label-token slots
        masked_pet_rep_mask_ids_lbl_logit = pet_rep_mask_ids_lbl_logit * (lbl_ids > 0).long()

    return masked_pet_rep_mask_ids_lbl_logit, lbl_ids, None
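
A minimal standalone sketch of the same gather-and-mask pattern on toy tensors (not from the repository); the sizes, the random logits standing in for the masked-LM output, and the variable names below are made up for illustration, and only torch is needed to run it:

import torch

bs, max_seq_len, vocab_size = 2, 8, 30           # toy sizes
num_lbl, max_num_lbl_tok = 3, 2                  # 3 candidate labels, up to 2 tokens each

pet_logits = torch.randn(bs, max_seq_len, vocab_size)                      # stand-in for the model output
mask_idx = torch.randint(0, max_seq_len, (bs, num_lbl, max_num_lbl_tok))   # mask position of each label token
lbl_ids = torch.randint(1, vocab_size, (bs, num_lbl, max_num_lbl_tok))     # label token ids
lbl_ids[:, :, 1] = 0                                                       # pretend the second token is padding

vs = pet_logits.shape[-1]
flat_idx = mask_idx.reshape(bs, num_lbl * max_num_lbl_tok)
mask_logits = torch.gather(pet_logits, 1, flat_idx[:, :, None].repeat(1, 1, vs))     # [bs, num_lbl*max_num_lbl_tok, vs]
mask_prob = mask_logits.reshape(bs, num_lbl, max_num_lbl_tok, vs).softmax(dim=-1)    # probabilities at mask positions
lbl_prob = torch.gather(mask_prob, 3, lbl_ids[..., None]).squeeze(3)                 # [bs, num_lbl, max_num_lbl_tok]
lbl_prob = lbl_prob * (lbl_ids > 0).long()                                           # zero out padded label tokens
print(lbl_prob.shape)  # torch.Size([2, 3, 2])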