in models/networks.py [0:0]
def forward(self, x, visual_feat):
    """Fuse a visual feature into an audio U-Net and predict a mask.

    Encodes spectrogram input `x` through five conv layers, tiles the
    projected visual feature over the bottleneck's spatial grid,
    concatenates the two, then decodes with skip connections. The final
    activation is affinely rescaled by ``2 * y - 1`` (maps [0, 1] to
    [-1, 1]; assumes the last upconv ends in a sigmoid — confirm).
    """
    # Encoder: keep every activation for the decoder's skip connections.
    encoder_layers = (
        self.audionet_convlayer1,
        self.audionet_convlayer2,
        self.audionet_convlayer3,
        self.audionet_convlayer4,
        self.audionet_convlayer5,
    )
    skips = []
    feat = x
    for layer in encoder_layers:
        feat = layer(feat)
        skips.append(feat)
    bottleneck = skips.pop()  # deepest activation; fused, not skipped

    # Project the visual feature, flatten it to a 1x1 spatial map, then
    # tile it across the bottleneck's height and width before fusing.
    vis = self.conv1x1(visual_feat)
    vis = vis.view(vis.shape[0], -1, 1, 1)
    vis = vis.repeat(1, 1, bottleneck.shape[-2], bottleneck.shape[-1])
    fused = torch.cat((vis, bottleneck), dim=1)

    # Decoder: each upconv takes the previous output concatenated with
    # the matching encoder activation, deepest skip first.
    up = self.audionet_upconvlayer1(fused)
    up = self.audionet_upconvlayer2(torch.cat((up, skips[3]), dim=1))
    up = self.audionet_upconvlayer3(torch.cat((up, skips[2]), dim=1))
    up = self.audionet_upconvlayer4(torch.cat((up, skips[1]), dim=1))
    out = self.audionet_upconvlayer5(torch.cat((up, skips[0]), dim=1))
    return out * 2 - 1