# in model.py [0:0]
def forward(self, img, spec, whichhead=0):
    """Compute projection-head embeddings for a video/audio input pair.

    Args:
        img: video input tensor, fed to ``self.video_network``.
        spec: audio spectrogram tensor, fed to ``self.audio_network``.
        whichhead: unused here; kept for backward compatibility with
            callers that pass it (head selection happens elsewhere).

    Returns:
        If ``self.return_features`` is set: the raw backbone features
        ``(img_features, aud_features)``. NOTE(review): this early return
        skips the batch-dimension restoration applied below, so a single
        sample comes back 1-D — confirm callers expect that.
        If ``self.hc == 1``: a ``(video, audio)`` pair of projected
        (optionally L2-normalized) 2-D tensors.
        If ``self.hc > 1``: a pair of lists with one projection per head.

    Raises:
        ValueError: if ``self.hc`` is less than 1 (previously this fell
        through and silently returned ``None``).
    """
    img_features = self.video_network(img).squeeze()
    aud_features = self.audio_network(spec).squeeze()
    if self.return_features:
        return img_features, aud_features
    # .squeeze() drops the batch dimension for a single sample; restore it
    # so the MLP heads and dim=1 normalization always see 2-D input.
    if aud_features.dim() == 1:
        aud_features = aud_features.unsqueeze(0)
    if img_features.dim() == 1:
        img_features = img_features.unsqueeze(0)
    if self.hc == 1:
        nce_img_features = self.mlp_v(img_features)
        nce_aud_features = self.mlp_a(aud_features)
        if self.norm_feat:
            nce_img_features = F.normalize(nce_img_features, p=2, dim=1)
            nce_aud_features = F.normalize(nce_aud_features, p=2, dim=1)
        return nce_img_features, nce_aud_features
    if self.hc > 1:
        # Multi-head: project through every head; note this returns lists.
        outs1 = []
        outs2 = []
        for head in range(self.hc):
            img_f = getattr(self, "mlp_v%d" % head)(img_features)
            aud_f = getattr(self, "mlp_a%d" % head)(aud_features)
            if self.norm_feat:
                img_f = F.normalize(img_f, p=2, dim=1)
                aud_f = F.normalize(aud_f, p=2, dim=1)
            outs1.append(img_f)
            outs2.append(aud_f)
        return outs1, outs2
    raise ValueError("self.hc must be >= 1, got %r" % (self.hc,))