loss_fn/simclr_infonce.py
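# Module-level context assumed by this excerpt (not shown here); `utils` is
# a repo-local helper exposing get_rank(), and LARGE_NUM matches the masking
# constant in the referenced SimCLR objective.py. `self.criterion` is
# presumably a cross-entropy over the concatenated logits with soft targets.
import torch
import torch.nn as nn
import utils  # assumed repo-local module
LARGE_NUM = 1e9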
def forward(self, output: torch.Tensor,
            target: torch.Tensor) -> torch.Tensor:
"""
Args:
output: BxC
target: BxC or BxKxC <-- In case of MIL NCE, K is the number of
positives for each batch element.
Following https://github.com/google-research/simclr/blob/master/objective.py
"""
    # Normalize before the gather, so that all gathered features are
    # normalized
    output = nn.functional.normalize(output, dim=-1, p=2)
    target = nn.functional.normalize(target, dim=-1, p=2)
    # To be consistent with the MIL-NCE input, fold K into the batch dim;
    # each output row then matches num_matching consecutive target rows
    elt_for_back_loss = 0
    if target.ndim == 3:
        num_matching = target.size(1)
        target_flat = target.reshape((-1, target.size(-1)))
        # Keep only the first positive for the reverse (target->output) loss
        target = target[:, elt_for_back_loss]
    else:
        num_matching = 1
        target_flat = target
    # Gather all the outputs and all the targets across replicas
    output_all = self.gather_embeddings(output)
    target_flat_all = self.gather_embeddings(target_flat)
    batch_size = output.size(0)
    replica_id = utils.get_rank()
    # One-hot labels marking, for each local sample, its own position in
    # the gathered batch: rows [replica_id * B, (replica_id + 1) * B)
    # labels_onehot, extra_zeros: (B, B_full)
    labels_onehot = torch.zeros((batch_size, output_all.size(0)),
                                dtype=output.dtype,
                                device=output.device)
    extra_zeros = torch.zeros((batch_size, output_all.size(0)),
                              dtype=output.dtype,
                              device=output.device)
    ones_diag = torch.eye(batch_size,
                          batch_size,
                          dtype=output.dtype,
                          device=output.device)
    labels_onehot[:, replica_id * batch_size:(replica_id + 1) *
                  batch_size] = ones_diag
    # (B, B_full) -> (B, B_full * num_matching): every positive column is
    # repeated once per matching target
    labels_onehot_interleaved = labels_onehot.repeat_interleave(
        num_matching, dim=1)
    # (B, C) x (B_full, C)^T -> (B, B_full)
    logits_aa = torch.mm(output, output_all.t() / self.temperature)
    # (B, C) x (B_full * num_matching, C)^T -> (B, B_full * num_matching)
    logits_ab = torch.mm(output, target_flat_all.t() / self.temperature)
    # Mask each sample's similarity to itself (it appears in output_all) by
    # subtracting a large constant before the softmax
    logits_aa = logits_aa - labels_onehot * LARGE_NUM
    loss = self.criterion(
        torch.cat([logits_ab, logits_aa], 1),
        torch.cat([labels_onehot_interleaved, extra_zeros], 1))
    if self.target_to_output_loss:
        # Keep only the first positive of each sample, since that is what
        # the reverse (target -> output) loss is computed against
        target_all = target_flat_all[elt_for_back_loss::num_matching]
        logits_bb = torch.mm(target, target_all.t() / self.temperature)
        logits_bb = logits_bb - labels_onehot * LARGE_NUM
        logits_ba = torch.mm(target, output_all.t() / self.temperature)
        loss = loss + self.criterion(
            torch.cat([logits_ba, logits_bb], 1),
            torch.cat([labels_onehot, extra_zeros], 1))
    return loss
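
# `self.gather_embeddings` is defined elsewhere in the repo and is not shown
# in this excerpt. A minimal sketch, assuming it is the usual SimCLR-style
# all-gather that keeps gradients flowing for the local shard; the body and
# fallback behaviour below are assumptions, not the repo's implementation:
import torch.distributed as dist

def gather_embeddings(self, embeddings: torch.Tensor) -> torch.Tensor:
    if not dist.is_available() or not dist.is_initialized():
        # Single-process run: nothing to gather
        return embeddings
    world_size = dist.get_world_size()
    gathered = [torch.zeros_like(embeddings) for _ in range(world_size)]
    dist.all_gather(gathered, embeddings)
    # all_gather returns tensors detached from the autograd graph;
    # re-insert the local tensor so its gradients flow through the loss
    gathered[utils.get_rank()] = embeddings
    return torch.cat(gathered, dim=0)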