def forward()

in touch_charts/models.py [0:0]


	def forward(self, gel, depth, ref_frame, empty, producing_sheet = False):
		"""Predict a depth map from a touch image and sample a fixed-size point cloud from it.

		Args:
			gel: batch of touch-sensor images, indexed as (batch, channel, H, W).
			depth: not used by this method (kept for interface compatibility).
			ref_frame: dict providing 'pos' (reshaped to (batch, -1)) and 'rot_M'
				(reshaped to (-1, 3, 3)); both are passed to ``self.project_depth``.
			empty: image of the untouched sensor, same layout as ``gel``.
			producing_sheet: what to do for a batch element with no touched pixels.
				If True, emit ``self.args.num_samples`` zero points for it. Otherwise,
				sample from all of its projected points.

		Returns:
			(pred_depth, pred_points):
			- pred_depth is the U-Net output passed through a sigmoid and scaled
			  into (0, 0.1).
			- pred_points has shape (batch, self.args.num_samples, 3).
		"""
		# get initial data; follow the device of the input images instead of forcing .cuda()
		batch_size = ref_frame['pos'].shape[0]
		device = gel.device
		pos = ref_frame['pos'].to(device).view(batch_size, -1)
		rot_m = ref_frame['rot_M'].to(device).view(-1, 3, 3)

		# U-Net prediction
		# downscale the image
		x1 = self.inc(gel)
		x2 = self.down1(x1)
		x3 = self.down2(x2)
		x4 = self.down3(x3)
		x5 = self.down4(x4)
		# upscale the image, merging the skip connections
		x = self.up1(x5, x4)
		x = self.up2(x, x3)
		x = self.up3(x, x2)
		x = self.up4(x, x1)
		# squash the raw prediction into (0, 0.1)
		pred_depth = torch.sigmoid(self.outc(x)) * 0.1

		# We only want to use points of the predicted point cloud whose pixels in the touch signal
		# differ enough from an untouched touch signal; otherwise they do not correspond to any
		# geometry of the object deforming the touch sensor's surface.
		# Per-pixel Euclidean distance over the channel dim, flattened to (batch, H*W) in row-major
		# pixel order.
		diff = (gel - empty).pow(2).sum(dim=1).sqrt().reshape(batch_size, -1)
		useful_points = diff > 0.001
		# project the depth values into 3D points
		projected_depths = self.project_depth(pred_depth.squeeze(1), pos, rot_m).view(batch_size, -1, 3)

		num_samples = self.args.num_samples
		pred_points = []
		for points, useful in zip(projected_depths, useful_points):
			# select only useful points (boolean indexing returns a new tensor, so no clone is needed)
			selected = points[useful]
			if selected.shape[0] == 0:
				if producing_sheet:
					# placeholder on the same device/dtype so torch.stack below succeeds
					pred_points.append(points.new_zeros((num_samples, 3)))
					continue
				# nothing was touched: fall back to every projected point
				selected = points

			# make the number of points in each element of a batch consistent:
			# tile until there are enough points, then draw a random subset
			while selected.shape[0] < num_samples:
				selected = torch.cat((selected, selected, selected, selected))
			idx = torch.randperm(selected.shape[0])[:num_samples]
			pred_points.append(selected[idx])
		pred_points = torch.stack(pred_points)

		return pred_depth, pred_points