in touch_charts/models.py [0:0]
def forward(self, gel, depth, ref_frame, empty, producing_sheet=False):
# get initial data
batch_size = ref_frame['pos'].shape[0]
pos = ref_frame['pos'].cuda().view(batch_size, -1)
rot_m = ref_frame['rot_M'].cuda().view(-1, 3, 3)
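# pos and rot_m describe the touch sensor's reference frame (position and 3x3 rotation);
# they are presumably consumed by project_depth below to place the predicted touch chart in world coordinates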
# U-Net prediction
# downscale the image
x1 = self.inc(gel)
x2 = self.down1(x1)
x3 = self.down2(x2)
x4 = self.down3(x3)
x5 = self.down4(x4)
# upscale the image
x = self.up1(x5, x4)
x = self.up2(x, x3)
x = self.up3(x, x2)
x = self.up4(x, x1)
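# x is back at the input resolution, with the encoder skip features (x4 ... x1) fused in at each scale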
pred_depth = self.outc(x)
# scale the prediction
pred_depth = torch.sigmoid(pred_depth) * 0.1
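# sigmoid squashes the raw output into (0, 1); the 0.1 factor caps the per-pixel predicted depth at 0.1,
# presumably the sensor's maximum depth range in the simulator's units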
# we only want to keep points in the predicted point cloud whose pixels in the touch signal
# are "different" enough from an untouched (no-contact) signal; otherwise they do not correspond to any
# geometry of the object that is deforming the touch sensor's surface.
diff = torch.sqrt((((gel.permute(0, 2, 3, 1) - empty.permute(0, 2, 3, 1)).view(batch_size, -1, 3)) ** 2).sum(dim=-1))
useful_points = diff > 0.001
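# diff is the per-pixel Euclidean distance between the touch image and the no-contact reference;
# pixels above the 0.001 threshold are treated as actual contact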
# project the depth values into 3D points
projected_depths = self.project_depth(pred_depth.squeeze(1), pos, rot_m).view(batch_size, -1, 3)
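# project_depth presumably back-projects each pixel's depth through the sensor pose (pos, rot_m),
# yielding one candidate 3D point per pixel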
pred_points = []
for points, useful in zip(projected_depths, useful_points):
# select only useful points
orig_points = points.clone()
points = points[useful]
if points.shape[0] == 0:
if producing_sheet:
pred_points.append(torch.zeros((self.args.num_samples, 3)).cuda())
continue
else:
points = orig_points
# make the number of points in each element of a batch consistent
while points.shape[0] < self.args.num_samples:
points = torch.cat((points, points, points, points))
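# randomly subsample exactly num_samples points so every element of the batch has the same size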
perm = torch.randperm(points.shape[0])
idx = perm[:self.args.num_samples]
points = points[idx]
pred_points.append(points)
pred_points = torch.stack(pred_points)
return pred_depth, pred_points
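# Hypothetical usage sketch (illustrative shapes only, not from the repo):
#   gel, empty: (B, 3, H, W) touch images; ref_frame['pos']: (B, 3); ref_frame['rot_M']: (B, 3, 3)
#   pred_depth, pred_points = model(gel, None, ref_frame, empty)  # the depth argument is unused in this forward pass
#   pred_depth:  (B, 1, H, W) with values in (0, 0.1), assuming outc has a single output channel
#   pred_points: (B, args.num_samples, 3)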