def get_decoupled_label_loss()

in scripts/adapet/ADAPET/src/adapet.py
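
This method is excerpted from the model class in adapet.py, so it depends on module-level names defined elsewhere in that file. A minimal sketch of that assumed context (not a verbatim copy of the file) is:

import torch
from functools import reduce  # used below to find token ids shared by all label choices

# assumed module-level device handle used by `.to(device)`
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

It also relies on attributes set up in the class constructor: self.dataset_reader, self.lbl_idx_lkup (assumed to be a one-hot/identity embedding lookup over labels), and self.loss (assumed to be an element-wise binary cross-entropy with reduction="none").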


    def get_decoupled_label_loss(self, batch):
        '''
        Compute the decoupled label loss: each candidate label is scored
        independently, with the correct label's verbalizer token(s) trained
        toward probability 1 and every incorrect label's token(s) toward 0.

        :param batch: input batch; batch["output"]["lbl"] holds the gold label index per example
        :return: scalar decoupled label loss
        '''

        pet_mask_ids, mask_idx, list_lbl = self.dataset_reader.prepare_batch(batch, self.get_pattern())
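        # pet_mask_ids: pattern inputs with mask tokens at the label slot(s);
        # mask_idx: positions of those mask tokens; list_lbl: candidate label
        # verbalizations for the pattern (nested lists when labels span multiple tokens)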
        lbl = batch["output"]["lbl"].to(device)

        # Datasets where the label has more than 1 token
        if isinstance(list_lbl[0], list):
            lbl_logits, lbl_ids, _ = self.get_multilbl_logits(pet_mask_ids, mask_idx,
                                                              list_lbl)  # [bs, num_lbl, max_num_lbl_tok]
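            # lbl_logits: per-token scores (MLM-head probabilities) for each candidate
            # label at the mask positions; lbl_ids: zero-padded token ids of each candidate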
            if "wsc" in self.config.dataset.lower():
                reshape_lbl_logits = lbl_logits.reshape(-1)  # [bs * num_lbl * max_num_lbl_tok]
                reshape_lbl = torch.ones_like(reshape_lbl_logits)
                real_mask = lbl_logits > 0

            else:
                # Remove tokens that are common across choices: reduce over the label
                # dimension of the token ids, keeping only ids shared by every candidate
                same_words_ids = torch.stack([reduce(lambda x, y: (x == y) * y, lbl_id) for lbl_id in lbl_ids],
                                             dim=0)  # [bs, max_num_lbl_tok]
                mask_same_words = (1 - (same_words_ids > 0).long()).unsqueeze(1).repeat(1, lbl_logits.shape[1],
                                                                                        1)  # [bs, num_lbl, max_num_lbl_tok]
                # Supervise only real label tokens: non-padding and not shared across choices
                real_mask = mask_same_words * (lbl_ids > 0)

                # Applying the mask to the lbl_logits
                lbl_logits = lbl_logits * mask_same_words  # [bs, num_lbl, max_num_lbl_tok]
                reshape_lbl_logits = lbl_logits.reshape(-1)  # [bs * num_lbl * max_num_lbl_tok]

                with torch.no_grad():
                    lkup_lbl = self.lbl_idx_lkup(lbl.long())  # [bs, num_lbl]
                reshape_lbl = lkup_lbl[:, :, None].repeat(1, 1, self.config.max_num_lbl_tok).reshape(-1)
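                # lkup_lbl is assumed to be a one-hot lookup: 1 at the gold label, 0 elsewhere.
                # Repeating it over max_num_lbl_tok gives target 1 for every token of the gold
                # label and target 0 for every token of the other labels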

            full_sup_loss = self.loss(reshape_lbl_logits, reshape_lbl)  # [bs * num_lbl * max_num_lbl_tok]
            full_sup_loss = full_sup_loss.reshape(lbl_logits.shape)

            pet_disc_loss = torch.sum(full_sup_loss * real_mask) / torch.sum(real_mask)
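            # The per-token BCE is averaged only over real positions: non-padding tokens
            # that are not shared across choices (for WSC, the non-padded target tokens)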

        # Datasets where the label is 1 token
        else:
            # Get lbl logits
            lbl_logits = self.get_single_logits(pet_mask_ids, mask_idx, list_lbl) # [bs, num_lbl]
            reshape_lbl_logits = lbl_logits.reshape(-1) # [bs*num_lbl, ]

            # lkup_lbl (one-hot lookup of lbl) is 1 at the true label's index and 0 otherwise
            with torch.no_grad():
                lkup_lbl = self.lbl_idx_lkup(lbl)  # [bs, num_lbl]
            reshape_lbl = lkup_lbl.reshape(-1) # [bs*num_lbl]

            pet_disc_loss = torch.mean(self.loss(reshape_lbl_logits, reshape_lbl))
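            # Each (example, label) probability is penalized independently with BCE rather
            # than via a softmax over labels; this is what makes the label loss "decoupled"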

        return pet_disc_loss
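
For intuition, here is a minimal, self-contained sketch of the single-token branch on toy tensors. It assumes self.loss is an element-wise binary cross-entropy over probabilities (reduction="none"), which is what the one-hot targets and the final mean suggest; the numbers and names below are illustrative only:

import torch
import torch.nn.functional as F

# 2 examples, 3 candidate labels: probability of each label's verbalizer token
# at the example's mask position (stand-in for the get_single_logits output)
lbl_probs = torch.tensor([[0.70, 0.20, 0.10],
                          [0.30, 0.60, 0.10]])          # [bs, num_lbl]
gold = torch.tensor([0, 1])                             # gold label index per example

# One-hot targets: 1 at the gold label, 0 everywhere else
# (mirroring the self.lbl_idx_lkup lookup above)
targets = F.one_hot(gold, num_classes=lbl_probs.shape[1]).float()  # [bs, num_lbl]

# Element-wise BCE, then a mean over all (example, label) pairs,
# mirroring torch.mean(self.loss(reshape_lbl_logits, reshape_lbl))
pet_disc_loss = F.binary_cross_entropy(lbl_probs.reshape(-1), targets.reshape(-1))
print(pet_disc_loss)  # scalar decoupled label loss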