in model/interpolation_net.py [0:0]
def get_pred(self, shape_x, shape_y, update_corr=True):
    """Predict the vertex trajectory that deforms ``shape_x`` towards ``shape_y``.

    For each timestep ``t`` the network ``self.rn_ec`` receives, per vertex of
    ``shape_x``, a 7-channel feature ``[position (3), displacement towards
    self.Pi @ shape_y.vert (3), t (1)]`` and returns a per-vertex displacement
    ``[n, 3]``; the predicted position is ``vert_x + t * displacement``.

    Args:
        shape_x: source shape; ``shape_x.vert`` is read as an ``[n, 3]`` tensor
            of vertex positions and ``shape_x.get_edge_index()`` supplies the
            connectivity passed to ``self.rn_ec``.
        shape_y: target shape; ``self.Pi @ shape_y.vert`` must have the same
            shape as ``shape_x.vert`` (``self.Pi`` is presumably the
            correspondence matrix maintained by ``self.match`` -- confirm).
        update_corr: if True, call ``self.match(shape_x, shape_y)`` first.

    Returns:
        Tensor of shape ``[n, 3, T + 1]`` with ``T = self.param.num_timesteps``:
        predicted vertex positions at times ``k / (T + 1)`` for
        ``k = 1, ..., T + 1`` (the last slice is at ``t = 1``).
    """
    if update_corr:
        self.match(shape_x, shape_y)

    vert_x = shape_x.vert  # [n, 3]
    num_vert = vert_x.shape[0]
    num_steps = self.param.num_timesteps + 1

    # Integer arange, then divide: a float-step torch.arange(0, 1, 1 / num_steps)
    # can return num_steps + 1 entries because of rounding in the step
    # (e.g. num_steps = 49), which would add a spurious timestep > 1.
    timesteps = (
        torch.arange(1, num_steps + 1, device=vert_x.device, dtype=vert_x.dtype)
        / num_steps
    ).view(-1, 1, 1)  # [T, 1, 1]

    # Time-independent per-vertex features: position and target displacement.
    target_disp = torch.mm(self.Pi, shape_y.vert) - vert_x  # [n, 3]
    static_feat = torch.cat((vert_x, target_disp), dim=1)  # [n, 6]

    # Append the timestep as the 7th channel for every (timestep, vertex) pair.
    points_in = torch.cat(
        (
            static_feat.unsqueeze(0).expand(num_steps, -1, -1),
            timesteps.expand(-1, num_vert, -1),
        ),
        dim=2,
    )  # [T, n, 7]

    edge_index = shape_x.get_edge_index()
    # rn_ec is evaluated one timestep at a time: batched processing through
    # torch-geometric (a single [T, n, 7] call) is no longer supported.
    displacement = torch.stack(
        [self.rn_ec(points_t, edge_index) for points_t in points_in.unbind(0)],
        dim=0,
    )  # [T, n, 3]

    point_pred_arr = vert_x.unsqueeze(0) + displacement * timesteps  # [T, n, 3]
    return point_pred_arr.permute(1, 2, 0)  # [n, 3, T]