in models/moco.py [0:0]
def forward(self, im_q, im_k=None):
    """
    MoCo forward pass.

    Input:
        im_q: a batch of query images
        im_k: a batch of key images; when None the call is inference-only
    Output:
        (logits, labels) during training:
            logits: Nx(1+K) similarity scores, positive in column 0, scaled by 1/T
            labels: N zeros (the positive key is always index 0)
        feats during inference (im_k is None): raw query-encoder features
    """
    # for inference, query model only
    if im_k is None:
        feats = self.encoder_q(im_q)
        return feats

    # compute query features: NxC, L2-normalized
    q = self.encoder_q(im_q)
    q = self.encoder_q.fc(q)
    q = nn.functional.normalize(q, dim=1)

    # compute key features without tracking gradients
    with torch.no_grad():  # no gradient to keys
        self._momentum_update_key_encoder()  # EMA update of the key encoder
        # shuffle across GPUs so BatchNorm statistics cannot leak identity
        if self.do_shuffle_bn:
            im_k, idx_unshuffle = dist_utils.batch_shuffle_ddp(im_k)
        k = self.encoder_k(im_k)  # keys: NxC
        k = self.encoder_k.fc(k)
        k = nn.functional.normalize(k, dim=1)
        # undo shuffle
        if self.do_shuffle_bn:
            k = dist_utils.batch_unshuffle_ddp(k, idx_unshuffle)

    # positive logits: Nx1
    # (einsum takes operands directly; the list form is deprecated in PyTorch)
    l_pos = torch.einsum('nc,nc->n', q, k).unsqueeze(-1)
    # negative logits: NxK against the memory queue (CxK)
    l_neg = torch.einsum('nc,ck->nk', q, self.queue.clone().detach())
    # logits: Nx(1+K)
    logits = torch.cat([l_pos, l_neg], dim=1)
    # apply temperature
    logits /= self.T

    # labels: positive key indicators. Allocate on the logits' device rather
    # than hard-coded .cuda() so the model also runs on CPU / non-default GPUs.
    labels = torch.zeros(logits.shape[0], dtype=torch.long, device=logits.device)

    # dequeue oldest keys and enqueue the current batch of keys
    self._dequeue_and_enqueue(k)
    return logits, labels