in trainer/loss.py [0:0]
def forward(self, video_features, text_features):
    """
    Compute the symmetric video/text contrastive loss for one batch.

    Both inputs are L2-normalized along dim 1, so every logit below is a
    cosine similarity divided by ``self.temperature``. For each modality the
    cross-modal similarity matrix is concatenated (along dim 1) with that
    modality's own similarity matrix, the latter multiplied element-wise by
    the mask from ``self._get_positive_mask`` and scaled by
    ``self.negative_w``. The concatenated logits and a matching target mask
    are handed to ``self.compute_loss``.

    Inputs shape (batch, embed_dim)
    Args:
        video_features: Visual embeddings (batch, embed_dim)
        text_features: Text embeddings (batch, embed_dim)
    Returns:
        The average of the video-side and text-side losses, each reduced
        with ``.mean()`` over whatever ``self.compute_loss`` returns.
    """
    batch_size = video_features.shape[0]

    # Normalize features
    video_features = nn.functional.normalize(video_features, dim=1)
    text_features = nn.functional.normalize(text_features, dim=1)

    # Inter-modality alignment
    logits_per_vid = (video_features @ text_features.t()) / self.temperature
    logits_per_text = (text_features @ video_features.t()) / self.temperature

    # Intra-modality alignment
    logits_clstr_vid = (video_features @ video_features.t()) / self.temperature
    logits_clstr_txt = (text_features @ text_features.t()) / self.temperature

    # NOTE(review): despite its name, this mask is applied to the intra-modal
    # logits that are then treated as extra negatives; presumably it zeroes
    # the diagonal (self-similarity) -- confirm against _get_positive_mask.
    positive_mask = self._get_positive_mask(batch_size)
    negatives_vid = logits_clstr_vid * positive_mask
    negatives_txt = logits_clstr_txt * positive_mask

    vid_logits = torch.cat([logits_per_vid, self.negative_w * negatives_vid], dim=1)
    txt_logits = torch.cat([logits_per_text, self.negative_w * negatives_txt], dim=1)

    # Target mask: identity over the cross-modal half (sample i is paired with
    # sample i), zeros over the appended intra-modal half. The identity is
    # created directly on the logits' device so the loss works on CPU and on
    # any GPU index instead of assuming the current CUDA device. float64
    # mirrors NumPy's default for an identity matrix, so the dtype that
    # compute_loss receives stays what callers have been getting.
    identity = torch.eye(batch_size, dtype=torch.float64, device=vid_logits.device)
    # Two separate tensors on purpose: compute_loss is not visible here and
    # may modify its mask argument in place.
    mask_v = torch.cat([identity, torch.zeros_like(negatives_vid)], dim=1)
    mask_t = torch.cat([identity, torch.zeros_like(negatives_txt)], dim=1)

    loss_i = self.compute_loss(vid_logits, mask_v)
    loss_t = self.compute_loss(txt_logits, mask_t)

    return (loss_i.mean() + loss_t.mean()) / 2