def forward()

in ssl/real-dataset/loss/nt_xent.py [0:0]


    def forward(self, zis, zjs, zs=None):
        """Compute the contrastive loss and its intra-pair (positive-only) term.

        The two views are stacked as ``[zjs, zis]``, a pairwise similarity
        matrix is computed with ``self.similarity_function``, and the positive
        and negative scores are read out of it.  When ``self.exact_cov`` is set
        the loss is a weighted difference of positive and negative distances
        (``1 - similarity``); otherwise ``self.criterion`` is applied to
        temperature-scaled logits whose column 0 holds the positive score.

        Args:
            zis: Representations of one view.  Assumed to have
                ``self.batch_size`` rows -- TODO confirm against the caller.
            zjs: Representations of the other view, same shape as ``zis``.
            zs: Optional representations of the unaugmented samples.  Only
                read when both ``self.exact_cov`` and
                ``self.exact_cov_unaug_sim`` are set; ignored otherwise.

        Returns:
            A tuple ``(loss, loss_intra)``.
        """
        n_rows = 2 * self.batch_size

        representations = torch.cat([zjs, zis], dim=0)
        similarity_matrix = self.similarity_function(representations, representations)

        # The positive scores are the two diagonals offset by +/- batch_size.
        l_pos = torch.diag(similarity_matrix, self.batch_size)
        r_pos = torch.diag(similarity_matrix, -self.batch_size)
        positives = torch.cat([l_pos, r_pos]).view(n_rows, 1)
        negatives = similarity_matrix[self.mask_samples_from_same_repr].view(n_rows, -1)

        if self.exact_cov:
            # 1 - sim = dist
            dist_neg = 1 - negatives
            dist_pos = 1 - positives

            num_negative = negatives.size(1)

            if self.exact_cov_unaug_sim and zs is not None:
                # Weights come from the similarity matrix of the unaugmented data.
                similarity_unaug = self.similarity_function(zs, zs)
                negatives_unaug = similarity_unaug[self.mask_samples_small].view(self.batch_size, -1)
                dist_neg_unaug = 1 - negatives_unaug
                w = (-dist_neg_unaug.detach() / self.temperature).exp()
                # Tile 2x2 (four copies) so the weights line up with `negatives`.
                w = torch.cat([w, w], dim=0)
                w = torch.cat([w, w], dim=1)
            else:
                w = (-dist_neg.detach() / self.temperature).exp()

            # Weights are detached above, so gradients flow only through the distances.
            w = w / (1 + w) / self.temperature / num_negative

            # Then we construct the loss function.
            w_pos = w.sum(dim=1, keepdim=True)
            pos_term = w_pos * dist_pos
            # keepdim=True keeps this (n_rows, 1); without it the subtraction
            # below would broadcast to an (n_rows, n_rows) matrix.
            neg_term = (w * dist_neg).sum(dim=1, keepdim=True)
            loss = (pos_term - neg_term).mean()
            loss_intra = self.beta * pos_term.mean()
        else:
            if self.add_one_in_neg:
                # Append a constant-one column to the logits.
                all_ones = torch.ones(n_rows, 1, device=self.device)
                logits = torch.cat((positives, negatives, all_ones), dim=1)
            else:
                logits = torch.cat((positives, negatives), dim=1)

            logits = logits / self.temperature

            # The positive score is always column 0, hence an all-zero target.
            labels = torch.zeros(n_rows, dtype=torch.long, device=self.device)
            loss = self.criterion(logits, labels)

            # Make positives stronger than negatives to trigger an additional term.
            loss_intra = -positives.sum() * self.beta / self.temperature

            normalizer = (1.0 + self.beta) * n_rows
            loss = loss / normalizer
            loss_intra = loss_intra / normalizer

        return loss, loss_intra