in criterions/nce_loss_moco.py [0:0]
def forward(self, output):
    """Compute the queue-based contrastive (InfoNCE-style) loss for one batch.

    For every (query, key, negatives) triple, the logits are built as
    ``[query . key, query . negatives] / self.T``. The positive is column 0,
    and ``self.xe_criterion`` scores them against all-zero labels.

    Args:
        output: list of 2-D ``(N, C)`` embedding tensors.
            ``output[0]`` is the query and ``output[1]`` the matching key of
            the primary stream; the key is passed to
            ``self._dequeue_and_enqueue`` after the logits are built.
            When ``self.other_queue`` is set, ``output[2]`` / ``output[3]``
            are the query / key of a second stream, whose negatives come from
            ``self.queue_other``. ``output[3]`` is passed to
            ``_dequeue_and_enqueue`` as ``okeys``.
            Embeddings are L2-normalised along dim 1 only when
            ``self.normalize_embedding`` is true; otherwise they are used raw.

    Returns:
        ``(loss, [loss_npid, loss_npid_other, loss_cmc_p2i, loss_cmc_i2p])``.
        The four terms are:

        * ``loss_npid``: query 0 vs key 1, negatives from ``self.queue``.
        * ``loss_cmc_p2i``: query 0 vs key 3, negatives from ``self.queue_other``.
        * ``loss_cmc_i2p``: query 2 vs key 1, negatives from ``self.queue``.
        * ``loss_npid_other``: query 2 vs key 3, negatives from ``self.queue_other``.

        Without ``other_queue``, only ``loss_npid`` is computed. The other
        three are zero scalars on the logits' device, and
        ``loss = loss_npid * self.npid0_w``. With ``other_queue``, ``loss``
        is summed over ``self.loss_list``: each ``"CMC"`` entry adds the
        ``cmc0_w`` / ``cmc1_w``-weighted cross terms, each ``"NPID"`` entry
        adds the ``npid0_w`` / ``npid1_w``-weighted same-stream terms, and
        other entries are ignored. An empty or unrecognised ``loss_list``
        leaves ``loss`` as the plain integer ``0``.

    Raises:
        AssertionError: if ``output`` is not a list.
    """
    assert isinstance(
        output, list
    ), "Model output should be a list of tensors. Got Type {}".format(type(output))

    def _embed(x):
        # Normalise only when configured; previously the disabled case left
        # the embedding variables unbound and crashed.
        if self.normalize_embedding:
            return nn.functional.normalize(x, dim=1, p=2)
        return x

    def _logits(query, key, negatives):
        # Positive logits: Nx1; negative logits: NxK ("ck" => negatives is
        # laid out as (C, K)). Result: Nx(1+K), temperature-scaled.
        l_pos = torch.einsum("nc,nc->n", [query, key]).unsqueeze(-1)
        l_neg = torch.einsum("nc,ck->nk", [query, negatives])
        return torch.cat([l_pos, l_neg], dim=1) / self.T

    q1 = _embed(output[0])
    k1 = _embed(output[1])
    # Snapshot each queue once before _dequeue_and_enqueue runs. Presumably
    # that call updates the queues in place, while autograd keeps the einsum
    # operand for the backward pass -- keep the clone.
    queue = self.queue.clone().detach()
    logits = _logits(q1, k1, queue)

    if self.other_queue:
        q2 = _embed(output[2])
        k2 = _embed(output[3])
        queue_other = self.queue_other.clone().detach()
        logits_p2i = _logits(q1, k2, queue_other)
        logits_i2p = _logits(q2, k1, queue)
        logits_other = _logits(q2, k2, queue_other)
        self._dequeue_and_enqueue(k1, okeys=k2)
    else:
        self._dequeue_and_enqueue(k1)

    # The positive key sits in column 0 of every logits matrix.
    labels = torch.zeros(logits.shape[0], device=logits.device, dtype=torch.int64)
    # No torch.squeeze: the logits are already (N, 1+K). Squeezing drops the
    # batch dimension when N == 1 and breaks the cross-entropy call.
    loss_npid = self.xe_criterion(logits, labels)

    # Placeholders for terms not computed in single-stream mode. They live on
    # the same device / float dtype as the real terms, and are kept as
    # separate tensors so callers may modify any one of them in place.
    loss_npid_other = logits.new_zeros(())
    loss_cmc_p2i = logits.new_zeros(())
    loss_cmc_i2p = logits.new_zeros(())

    if self.other_queue:
        loss_cmc_p2i = self.xe_criterion(logits_p2i, labels)
        loss_cmc_i2p = self.xe_criterion(logits_i2p, labels)
        loss_npid_other = self.xe_criterion(logits_other, labels)
        loss = 0
        for ltype in self.loss_list:
            if ltype == "CMC":
                loss = loss + (loss_cmc_p2i * self.cmc0_w + loss_cmc_i2p * self.cmc1_w)
            elif ltype == "NPID":
                loss = loss + loss_npid * self.npid0_w
                loss = loss + loss_npid_other * self.npid1_w
    else:
        # Single-stream mode: loss_list is not consulted.
        loss = loss_npid * self.npid0_w

    return loss, [loss_npid, loss_npid_other, loss_cmc_p2i, loss_cmc_i2p]