in vision_charts/models.py [0:0]
def forward(self, img_occ, img_unocc, batch):
    """Iteratively deform the vision-chart vertices, optionally next to touch charts.

    Starting from ``self.initial_positions`` (broadcast over the batch), the
    vertex positions are refined for a fixed number of steps. At each step the
    current positions (plus optional image features and touch/vision flags) are
    fed to ``self.mesh_decoder``, whose output is added as an offset to the
    vision-chart vertices only.

    Args:
        img_occ: first image input forwarded to ``self.img_encoder``; its
            first dimension is used as the batch size.
        img_unocc: second image input forwarded to ``self.img_encoder``.
        batch: mapping read only when ``self.args.use_touch`` is set. Uses
            ``batch['sheets']`` (reshaped to ``(batch_size, -1, 3)``) and
            ``batch['successful']`` (broadcast to
            ``(batch_size, 4 * self.args.num_grasps, 25)``; presumably one
            flag per touch chart of 25 vertices -- confirm against the loader).

    Returns:
        Tensor of vertex positions with the vision-chart vertices first,
        followed (when touch is used) by the touch-chart vertices, which are
        returned unchanged.
    """
    num_steps = 3  # number of deformation refinements

    # initial data
    batch_size = img_occ.shape[0]
    cur_vertices = self.initial_positions.unsqueeze(0).expand(batch_size, -1, -1)
    size_vision_charts = cur_vertices.shape[1]
    # Follow the model's device rather than hard-coding ``.cuda()`` so this
    # works on CPU and on non-default GPUs.
    device = cur_vertices.device

    mask = None
    if self.args.use_touch:
        # Append touch chart positions to the graph definition, after the
        # vision charts.
        sheets = batch['sheets'].to(device).reshape(batch_size, -1, 3)
        cur_vertices = torch.cat((cur_vertices, sheets), dim=1)

        # The flag channel does not change between deformation steps, so it
        # is built once here instead of inside the loop.
        vision_chart_mask = torch.full(
            (batch_size, size_vision_charts, 1), 2.0, device=device
        )  # flag corresponding to vision
        successful = torch.as_tensor(
            batch['successful'], dtype=torch.float32, device=device
        )
        touch_chart_mask = successful.unsqueeze(-1).expand(
            batch_size, 4 * self.args.num_grasps, 25
        )
        touch_chart_mask = touch_chart_mask.reshape(batch_size, -1, 1)
        mask = torch.cat((vision_chart_mask, touch_chart_mask), dim=1)

    # cycle through deformation
    for _ in range(num_steps):
        vertex_features = cur_vertices.clone()
        # add vision features (prepended to the coordinates)
        if self.args.use_occluded or self.args.use_unoccluded:
            vert_img_features = self.img_encoder(img_occ, img_unocc, cur_vertices)
            vertex_features = torch.cat((vert_img_features, vertex_features), dim=-1)
        # add the touch/vision flag channel
        if mask is not None:
            vertex_features = torch.cat((vertex_features, mask), dim=-1)

        # deform the vertex positions
        vertex_positions = self.mesh_decoder(vertex_features, self.adj_info)

        # Offset only the vision charts; touch charts stay where they are.
        # Done out-of-place: zeroing the decoder output in place can break
        # autograd when the decoder's last op saves its output for backward.
        moved = (
            cur_vertices[:, :size_vision_charts]
            + vertex_positions[:, :size_vision_charts]
        )
        cur_vertices = torch.cat(
            (moved, cur_vertices[:, size_vision_charts:]), dim=1
        )
    return cur_vertices