def forward()

in vision_charts/models.py [0:0]


	def forward(self, img_occ, img_unocc, batch):
		"""Deform the vision-chart vertices with three passes of the mesh decoder.

		Args:
			img_occ: occluded image input, forwarded to ``self.img_encoder``.
				Its first dimension is used as the batch size.
			img_unocc: unoccluded image input, forwarded to ``self.img_encoder``.
			batch: dict-like batch. Only read when ``self.args.use_touch`` is set,
				in which case it must provide ``'sheets'`` (touch-chart vertex
				positions, viewed as ``(batch_size, -1, 3)``) and ``'successful'``
				(one flag per touch chart; a list, array or tensor — presumably of
				shape ``(batch_size, 4 * num_grasps)``, TODO confirm against the
				data loader).

		Returns:
			Tensor ``(batch_size, n_vertices, D)``: the deformed vision-chart
			vertices followed, when touch is enabled, by the touch-chart vertices,
			which are left unchanged (D is 3 whenever touch charts are appended).
		"""
		# initial data
		batch_size = img_occ.shape[0]
		cur_vertices = self.initial_positions.unsqueeze(0).expand(batch_size, -1, -1)
		size_vision_charts = cur_vertices.shape[1]
		# Keep every auxiliary tensor on the vertices' device instead of
		# hard-coding `.cuda()` (the concatenations below require a match anyway).
		device = cur_vertices.device

		# Per-vertex chart-type flag appended to the features when touch is used.
		mask = None
		if self.args.use_touch:
			# append touch chart positions to the graph definition
			sheets = batch['sheets'].to(device).view(batch_size, -1, 3)
			cur_vertices = torch.cat((cur_vertices, sheets), dim=1)

			# The flag is identical for every deformation pass, so build it once.
			# Vision vertices are flagged with 2; every vertex of a touch chart
			# carries that chart's 'successful' value.
			num_touch_charts = 4 * self.args.num_grasps
			verts_per_touch_chart = sheets.shape[1] // num_touch_charts
			vision_chart_mask = torch.full(
				(batch_size, size_vision_charts, 1), 2.0, device=device
			)
			touch_chart_mask = torch.as_tensor(
				batch['successful'], dtype=torch.float32, device=device
			)
			touch_chart_mask = touch_chart_mask.unsqueeze(-1).expand(
				batch_size, num_touch_charts, verts_per_touch_chart
			)
			touch_chart_mask = touch_chart_mask.contiguous().view(batch_size, -1, 1)
			mask = torch.cat((vision_chart_mask, touch_chart_mask), dim=1)

		# cycle through deformation
		for _ in range(3):
			vertex_features = cur_vertices.clone()
			# add vision features
			if self.args.use_occluded or self.args.use_unoccluded:
				vert_img_features = self.img_encoder(img_occ, img_unocc, cur_vertices)
				vertex_features = torch.cat((vert_img_features, vertex_features), dim=-1)
			# add mask for touch charts
			if mask is not None:
				vertex_features = torch.cat((vertex_features, mask), dim=-1)

			# deform the vertex positions
			vertex_positions = self.mesh_decoder(vertex_features, self.adj_info)
			# Avoid deforming the touch chart positions. Built out-of-place so the
			# decoder's output tensor is never modified in place (an in-place write
			# breaks autograd if the decoder's last op saved its output).
			vertex_positions = torch.cat(
				(
					vertex_positions[:, :size_vision_charts],
					torch.zeros_like(vertex_positions[:, size_vision_charts:]),
				),
				dim=1,
			)
			cur_vertices = cur_vertices + vertex_positions

		return cur_vertices