def calc_loss()

in trainers/catex.py [0:0]


    def calc_loss(self, logits, label, image=None, image_features=None, text_features=None, return_norm=False):
        """Compute the total training loss for one batch.

        The loss is a weighted sum of the terms below (weights are the
        hard-coded constants in the code):

        1. Classification (x2.0): ``ce_criterion(logits, label)``. While the
           synthetic-OOD gate ``synth_ood_on`` is open and
           ``OOD_PROMPT_CE_LOSS`` is set, the logits are first recomputed
           against ``[text_features; ood_text_features]`` and each sample's
           own-class OOD column is masked to -10.
        2. Prompt perturbation (x0.1): ``model.calc_prompt_loss`` between
           ``text_features`` and a randomly perturbed copy computed under
           ``torch.no_grad()``. Enabled by ``OOD_PROMPT`` + ``ID_PERTURB_LOSS``
           from ``START_EPOCH`` on.
        3. Synthetic outlier exposure (``is_large_ID()`` true, gate open):
           outliers come from ``id_pool.gen_ood`` and feed the optional OE
           (x0.5), OOD-prompt orthogonality (x0.1), OOD-prompt CE (x0.5) and
           margin losses.
        4. Real outlier exposure (x0.1, ``is_large_ID()`` false): OE loss on
           the next batch drawn from ``self.ood_loader``.

        Args:
            logits: classification logits of shape ``(nb, ncls)``.
            label: integer class targets, one per sample.
            image: not used by this method; kept for interface compatibility.
            image_features: image embeddings for the batch; L2-normalized
                along dim 1 here when the gate is open and ``return_norm`` is
                False.
            text_features: ID text features indexed by class (assumed
                ``(ncls, dim)`` -- TODO confirm). Required.
            return_norm: when True, ``image_features`` are used as given
                (presumably already normalized by the caller -- confirm).

        Returns:
            The accumulated loss tensor.

        Raises:
            AssertionError: if ``text_features`` is None.
        """
        assert text_features is not None
        nb, ncls = logits.shape
        trainer_cfg = self.cfg.TRAINER

        # Prompt-based terms only start at START_EPOCH.
        ood_prompt_on = trainer_cfg.OOD_PROMPT and self.epoch >= trainer_cfg.START_EPOCH
        large_id = self.is_large_ID()
        # One gate shared by steps 1 and 3, so ``ood_text_features`` is always
        # defined whenever step 3 uses it.
        synth_ood_on = ood_prompt_on and large_id and self.id_pool.ready()

        # 1. Classification.
        if synth_ood_on:
            if not return_norm:
                image_features = F.normalize(image_features, p=2, dim=1)
            ood_text_features = self.model.get_text_features(ood_prompt=True)

            if trainer_cfg.OOD_PROMPT_CE_LOSS:
                # ID prompts come first, so class c's OOD prompt is column ncls + c
                # (assumes text_features has ncls rows).
                logits = self.model.get_logits(image_features, torch.cat((text_features, ood_text_features)))
                # Mask the sample's own-class OOD prompt; -10 stands in for -inf.
                logits[torch.arange(nb), label + ncls] = -10.

        loss = 2. * self.ce_criterion(logits, label)

        # 2. Prompt perturbation.
        perturbed_text_features = None
        if ood_prompt_on and trainer_cfg.ID_PERTURB_LOSS:
            with torch.no_grad():
                perturbed_text_features = self.model.get_text_features(perturb=random.choice(perturb_methods))
            loss += 0.1 * self.model.calc_prompt_loss(text_features, perturbed_text_features)

        if not large_id:
            # 4. Real outlier exposure on the auxiliary OOD loader.
            # Keep a single iterator alive across calls. Re-creating it every step
            # (``next(iter(loader))``) restarts DataLoader workers each time and,
            # for an unshuffled loader, always yields the same first batch.
            ood_iter = getattr(self, '_ood_iter', None)
            if ood_iter is None or getattr(self, '_ood_iter_loader', None) is not self.ood_loader:
                ood_iter = iter(self.ood_loader)
            try:
                ood_data, _ = next(ood_iter)
            except StopIteration:
                # Loader exhausted: begin a new pass.
                ood_iter = iter(self.ood_loader)
                ood_data, _ = next(ood_iter)
            self._ood_iter = ood_iter
            self._ood_iter_loader = self.ood_loader

            ood_data = ood_data.to(self.device)
            ood_logits = self.model(ood_data)
            # Cross-entropy towards the uniform distribution over ID classes.
            loss += 0.1 * -(ood_logits.mean(1) - torch.logsumexp(ood_logits, dim=1)).mean()

        elif synth_ood_on:
            # 3. Synthetic outlier exposure.
            # When the batch has fewer rows than id_pool.queue, pass the classes
            # present in the batch (presumably to restrict outlier synthesis to
            # them -- confirm against id_pool.gen_ood).
            if logits.size(0) < self.id_pool.queue.size(0):
                cls_mask = torch.unique(label).cpu()
            else:
                cls_mask = None

            if trainer_cfg.OOD_ANCHOR:
                # Reuse the perturbed features from step 2 when available.
                if perturbed_text_features is None:
                    with torch.no_grad():
                        perturbed_text_features = self.model.get_text_features(perturb=random.choice(perturb_methods))
                text_anchors = torch.stack((text_features, perturbed_text_features), dim=1).detach()
            else:
                text_anchors = None
            ood_features, ood_labels = self.id_pool.gen_ood(anchors=text_anchors, device=self.device, cls_mask=cls_mask)

            if trainer_cfg.OOD_OE_LOSS:
                ood_logits = self.model.get_logits(ood_features, text_features, logit_scale=1.)
                # Cross-entropy towards the uniform distribution over ID classes.
                loss += 0.5 * -(ood_logits.mean(1) - torch.logsumexp(ood_logits, dim=1)).mean()

            if trainer_cfg.OOD_PROMPT_ORTH:
                assert trainer_cfg.OOD_PROMPT_NUM > 1
                all_ood_text_features = self.model.get_all_ood_text_features()
                # Batched Gram matrix of the OOD prompt features,
                # e.g. (1000, 5, 512) x (1000, 512, 5) -> (1000, 5, 5).
                ood_sim_matrix = torch.bmm(all_ood_text_features, all_ood_text_features.transpose(1, 2))
                diag = torch.arange(ood_sim_matrix.shape[-1])
                ood_sim_matrix[:, diag, diag] = 0.  # ignore self-similarity
                loss += 0.1 * ood_sim_matrix.mean()

            if trainer_cfg.OOD_PROMPT_CE_LOSS:
                # OOD prompts come first here, so class c's ID prompt is column ncls + c.
                ood_logits = self.model.get_logits(ood_features, torch.cat((ood_text_features, text_features)))
                # Mask the outlier's own-class ID prompt; -10 stands in for -inf.
                ood_logits[torch.arange(ood_logits.shape[0]), ood_labels + ncls] = -10.
                loss += 0.5 * self.ce_criterion(ood_logits, ood_labels)

            if trainer_cfg.OOD_PROMPT_MARGIN_LOSS:
                soft_margin = trainer_cfg.OOD_PROMPT_MARGIN_SOFT_LOSS
                logit_scale = self.model.logit_scale.exp() if soft_margin else 1.

                # Similarity of each sample to its own class's ID / OOD prompt.
                # NOTE: an earlier variant used the max similarity over all
                # classes instead of the target class.
                id_pos_sim = (image_features * text_features[label]).sum(dim=-1) * logit_scale
                id_neg_sim = (image_features * ood_text_features[label]).sum(dim=-1) * logit_scale
                ood_pos_sim = (ood_features * ood_text_features[ood_labels]).sum(dim=-1) * logit_scale
                ood_neg_sim = (ood_features * text_features[ood_labels]).sum(dim=-1) * logit_scale

                if soft_margin:
                    # Two-way softmax where index 0 (the "positive" prompt) is the target.
                    loss += F.cross_entropy(torch.stack((id_pos_sim, id_neg_sim), dim=1),
                                            torch.zeros((len(id_pos_sim),), dtype=torch.long, device=self.device)) + \
                            F.cross_entropy(torch.stack((ood_pos_sim, ood_neg_sim), dim=1),
                                            torch.zeros((len(ood_pos_sim),), dtype=torch.long, device=self.device))
                else:
                    # Hinge with zero margin: penalize negatives that beat positives.
                    loss += (id_neg_sim - id_pos_sim).relu().mean() + (ood_neg_sim - ood_pos_sim).relu().mean()

        return loss