in ssl/real-dataset/loss/nt_xent.py [0:0]
def forward(self, zis, zjs, zs):
    """Compute the NT-Xent contrastive loss or its exact-covariance variant.

    Args:
        zis: Embeddings of one augmented view, one row per sample
            (presumably ``self.batch_size`` rows -- TODO confirm with caller).
        zjs: Embeddings of the other augmented view, same layout as ``zis``.
            Row k of ``zis`` and row k of ``zjs`` form a positive pair.
        zs: Optional embeddings of the un-augmented inputs. Only read when
            both ``self.exact_cov`` and ``self.exact_cov_unaug_sim`` are set;
            may be ``None``.

    Returns:
        Tuple ``(loss, loss_intra)`` of scalar tensors.

    Exact-covariance mode (``self.exact_cov`` set):
        Similarities are turned into distances ``r = 1 - sim``. The loss is a
        weighted difference of positive and negative distances. The weights
        are ``sigmoid(-r / T) / (T * K)``, where ``T`` is the temperature and
        ``K`` is the number of negatives per row. They are computed from
        detached tensors (optionally from the un-augmented similarities), so
        gradients flow only through the distances.

    Cross-entropy mode (otherwise):
        Each of the ``2 * batch_size`` rows is classified with its positive
        in column 0, optionally with an extra constant logit. Both returned
        terms are divided by ``(1 + beta) * 2 * batch_size``.
    """
    n = self.batch_size
    # Stack the two views: rows [0, n) are zjs, rows [n, 2n) are zis.
    representations = torch.cat([zjs, zis], dim=0)
    similarity_matrix = self.similarity_function(representations, representations)

    # Positive pairs sit on the +n / -n off-diagonals: (k, k + n) and (k + n, k).
    pos_upper = torch.diag(similarity_matrix, n)
    pos_lower = torch.diag(similarity_matrix, -n)
    positives = torch.cat([pos_upper, pos_lower]).view(2 * n, 1)
    # NOTE(review): the mask presumably drops the diagonal and the positive
    # pairs, leaving 2n - 2 negatives per row -- confirm where it is built.
    negatives = similarity_matrix[self.mask_samples_from_same_repr].view(2 * n, -1)

    if self.exact_cov:
        # Distances: 1 - similarity.
        r_neg = 1 - negatives
        r_pos = 1 - positives
        num_negative = negatives.size(1)

        if self.exact_cov_unaug_sim and zs is not None:
            # Weights come from similarities of the un-augmented data.
            similarity_unaug = self.similarity_function(zs, zs)
            negatives_unaug = similarity_unaug[self.mask_samples_small].view(n, -1)
            r_neg_unaug = 1 - negatives_unaug
            w = (-r_neg_unaug.detach() / self.temperature).exp()
            # Tile 2x along both axes so the weights line up with r_neg
            # (presumably (n, n - 1) -> (2n, 2n - 2) when mask_samples_small
            # drops the diagonal -- confirm).
            w = torch.cat([w, w], dim=0)
            w = torch.cat([w, w], dim=1)
        else:
            w = (-r_neg.detach() / self.temperature).exp()

        # w / (1 + w) == sigmoid(-r / T); then scale by 1 / (T * num_negative).
        w = w / (1 + w) / self.temperature / num_negative

        w_pos = w.sum(dim=1, keepdim=True)
        # keepdim keeps both operands (2n, 1); without it the subtraction
        # broadcast to a (2n, 2n) intermediate before the mean.
        neg_term = (w * r_neg).sum(dim=1, keepdim=True)
        loss = (w_pos * r_pos - neg_term).mean()
        loss_intra = self.beta * (w_pos * r_pos).mean()
    else:
        # Allocate constants directly on the data's device (and labels
        # directly as int64) instead of building them on the CPU and copying
        # them over on every call.
        device = positives.device
        if self.add_one_in_neg:
            # Extra constant logit of 1 (before temperature scaling).
            all_ones = torch.ones(2 * n, 1, device=device)
            logits = torch.cat((positives, negatives, all_ones), dim=1)
        else:
            logits = torch.cat((positives, negatives), dim=1)
        logits = logits / self.temperature

        # The positive is always column 0 of every row.
        labels = torch.zeros(2 * n, dtype=torch.long, device=device)
        loss = self.criterion(logits, labels)

        # Make positives stronger than negatives via an additional term:
        # minimizing it (for beta > 0) increases positive similarities.
        loss_intra = -positives.sum() * self.beta / self.temperature

        # NOTE(review): dividing by 2n suggests self.criterion sums over rows
        # (reduction="sum") -- confirm where it is constructed.
        normalizer = (1.0 + self.beta) * 2 * n
        loss = loss / normalizer
        loss_intra = loss_intra / normalizer

    return loss, loss_intra