def forward()

in criterions/nce_loss_moco.py
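
Computes a MoCo-style NCE loss: each query embedding is scored against its positive key and against a memory queue of negatives, and the resulting logits are passed to cross-entropy with the positive fixed at index 0. When a second queue is configured (`self.other_queue`), the same construction is repeated across the two embedding pairs to produce additional CMC and NPID terms, which are combined with their configured weights.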


    def forward(self, output):
        assert isinstance(
            output, list
        ), "Model output should be a list of tensors. Got Type {}".format(type(output))
        
        # Normalize embeddings so that dot products are cosine similarities.
        if self.normalize_embedding:
            normalized_output1 = nn.functional.normalize(output[0], dim=1, p=2)
            normalized_output2 = nn.functional.normalize(output[1], dim=1, p=2)
            if self.other_queue:
                normalized_output3 = nn.functional.normalize(output[2], dim=1, p=2)
                normalized_output4 = nn.functional.normalize(output[3], dim=1, p=2)
        else:
            # Fall back to the raw embeddings; without this branch the code
            # below raises a NameError when normalize_embedding is False.
            normalized_output1, normalized_output2 = output[0], output[1]
            if self.other_queue:
                normalized_output3, normalized_output4 = output[2], output[3]

        # positive logits: Nx1
        l_pos = torch.einsum('nc,nc->n', [normalized_output1, normalized_output2]).unsqueeze(-1)
        
        # negative logits: NxK
        l_neg = torch.einsum('nc,ck->nk', [normalized_output1, self.queue.clone().detach()])
        
        # logits: Nx(1+K)
        logits = torch.cat([l_pos, l_neg], dim=1)
        
        # apply temperature
        logits /= self.T

        if self.other_queue:
            # cross logits (p2i): first-modality queries scored against the
            # second modality's positive keys and its negative queue
            l_pos_p2i = torch.einsum('nc,nc->n', [normalized_output1, normalized_output4]).unsqueeze(-1)
            l_neg_p2i = torch.einsum('nc,ck->nk', [normalized_output1, self.queue_other.clone().detach()])
            logits_p2i = torch.cat([l_pos_p2i, l_neg_p2i], dim=1)
            logits_p2i /= self.T

            
            # cross logits (i2p): second-modality queries scored against the
            # first modality's positive keys and its negative queue
            l_pos_i2p = torch.einsum('nc,nc->n', [normalized_output3, normalized_output2]).unsqueeze(-1)
            l_neg_i2p = torch.einsum('nc,ck->nk', [normalized_output3, self.queue.clone().detach()])
            logits_i2p = torch.cat([l_pos_i2p, l_neg_i2p], dim=1)
            logits_i2p /= self.T

            
            # within-modality logits for the second modality (its own NPID term)
            l_pos_other = torch.einsum('nc,nc->n', [normalized_output3, normalized_output4]).unsqueeze(-1)
            l_neg_other = torch.einsum('nc,ck->nk', [normalized_output3, self.queue_other.clone().detach()])
            logits_other = torch.cat([l_pos_other, l_neg_other], dim=1)
            logits_other /= self.T
            
        # update the memory queue(s) with the freshly computed keys
        if self.other_queue:
            self._dequeue_and_enqueue(normalized_output2, okeys=normalized_output4)
        else:
            self._dequeue_and_enqueue(normalized_output2)

        
        # the positive logit sits at index 0, so every target label is 0
        labels = torch.zeros(
            logits.shape[0], device=logits.device, dtype=torch.int64
        )
        
        loss_npid = self.xe_criterion(torch.squeeze(logits), labels)

        # placeholders so the returned loss list always has four entries
        loss_npid_other = torch.tensor(0)
        loss_cmc_p2i = torch.tensor(0)
        loss_cmc_i2p = torch.tensor(0)
        
        if self.other_queue:
            loss_cmc_p2i = self.xe_criterion(torch.squeeze(logits_p2i), labels)
            loss_cmc_i2p = self.xe_criterion(torch.squeeze(logits_i2p), labels)
            loss_npid_other = self.xe_criterion(torch.squeeze(logits_other), labels)

            # combine the enabled loss terms with their configured weights
            curr_loss = 0
            for ltype in self.loss_list:
                if ltype == "CMC":
                    curr_loss += loss_cmc_p2i * self.cmc0_w + loss_cmc_i2p * self.cmc1_w
                elif ltype == "NPID":
                    curr_loss += loss_npid * self.npid0_w
                    curr_loss += loss_npid_other * self.npid1_w
        else:
            curr_loss = loss_npid * self.npid0_w

        loss = curr_loss

        return loss, [loss_npid, loss_npid_other, loss_cmc_p2i, loss_cmc_i2p]
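
For orientation, the snippet below is a minimal, self-contained sketch of the single-queue logit construction used above (no other_queue branch). The batch size, feature dimension, queue length, and temperature value are illustrative assumptions, not values taken from the repository.

    import torch
    import torch.nn as nn

    N, C, K = 8, 128, 1024                                           # assumed batch size, feature dim, queue length
    q = nn.functional.normalize(torch.randn(N, C), dim=1, p=2)       # query embeddings (output[0])
    k = nn.functional.normalize(torch.randn(N, C), dim=1, p=2)       # positive key embeddings (output[1])
    queue = nn.functional.normalize(torch.randn(C, K), dim=0, p=2)   # memory queue of negative keys

    l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)           # positive logits: Nx1
    l_neg = torch.einsum('nc,ck->nk', [q, queue])                    # negative logits: NxK
    logits = torch.cat([l_pos, l_neg], dim=1) / 0.07                 # temperature 0.07 is an assumed value

    labels = torch.zeros(N, dtype=torch.int64)                       # the positive is always at index 0
    loss = nn.CrossEntropyLoss()(logits, labels)

The forward method above repeats this pattern once per queue and then mixes the resulting terms with the npid*/cmc* weights.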