in vision_charts/models.py [0:0]
def forward(self, img_occ, img_unocc, cur_vertices):
    """Encode the selected image pair and pool features at the given vertices.

    Two images are always stacked along dim 1. When only one source is
    selected via ``self.args`` it is duplicated, so the stacked size stays
    the same regardless of the mode (a legacy decision).

    Args:
        img_occ: Occluded image tensor. NOTE(review): layout not visible
            here; presumably (batch, channels, H, W) -- confirm with caller.
        img_unocc: Unoccluded image tensor, concatenated with / in place of
            ``img_occ`` along dim 1.
        cur_vertices: Passed unchanged to ``self.pooling``.

    Returns:
        The result of ``self.pooling(feature_maps, cur_vertices)``.
        ``feature_maps`` holds up to three intermediate layer outputs,
        in layer order, followed by the output of the last layer actually
        applied.
    """
    args = self.args

    # Select what gets stacked; ``use_unoccluded`` takes precedence over
    # ``use_occluded``, and with neither set both images are used.
    if args.use_unoccluded:
        pair = (img_unocc, img_unocc)
    elif args.use_occluded:
        pair = (img_occ, img_occ)
    else:
        pair = (img_occ, img_unocc)
    x = torch.cat(pair, dim=1)

    # Layer indices whose outputs are tapped: step back from the last
    # layer by num_img_layers, 2*num_img_layers and 3*num_img_layers.
    last_index = len(self.layers) - 1
    step = args.num_img_layers
    tap_indices = tuple(last_index - k * step for k in (1, 2, 3))

    feature_maps = []
    for index, layer in enumerate(self.layers):
        # Stop applying layers once the last dimension is smaller than the
        # configured kernel size (presumably the spatial width -- confirm).
        if x.shape[-1] < args.size_img_ker:
            break
        x = layer(x)
        if index in tap_indices:
            feature_maps.append(x)
    # The most recent output is always included, even if it was also tapped.
    feature_maps.append(x)

    return self.pooling(feature_maps, cur_vertices)