def get_pred()

in model/interpolation_net.py [0:0]


    def get_pred(self, shape_x, shape_y, update_corr=True):
        """Predict the interpolation sequence deforming ``shape_x`` towards ``shape_y``.

        With ``T = self.param.num_timesteps + 1`` evenly spaced time values
        ``t_k = k / T`` (k = 1..T, so the last value is exactly 1), the network
        ``self.rn_ec`` is evaluated once per time step on per-vertex features
        ``[vert_x, Pi @ vert_y - vert_x, t_k]``. Its output is scaled by ``t_k``
        and added to ``shape_x.vert``.

        Args:
            shape_x: source shape; must expose ``vert`` (an ``[n, d]`` tensor,
                presumably d == 3 given the 7-channel network input) and
                ``get_edge_index()``.
            shape_y: target shape; must expose ``vert`` such that
                ``self.Pi @ shape_y.vert`` is ``[n, d]``.
            update_corr: if True, refresh the correspondence (``self.Pi``) via
                ``self.match(shape_x, shape_y)`` before predicting.

        Returns:
            Tensor of shape ``[n, d, T]`` holding the predicted vertex
            positions at every time step.
        """
        if update_corr:
            self.match(shape_x, shape_y)

        vert_x = shape_x.vert
        num_steps = self.param.num_timesteps + 1

        # t_k = k / T built from an integer range. A float-step
        # torch.arange(0, 1, 1 / T) is subject to rounding when compared with
        # `end` and can return T + 1 samples (the extra one extrapolating
        # beyond t = 1).
        timesteps = (
            torch.arange(1, num_steps + 1, device=vert_x.device, dtype=torch.float)
            / num_steps
        ).view(-1, 1, 1)  # [T, 1, 1]

        # Static per-vertex features: position and offset to the
        # corresponding (Pi-mapped) target vertex.
        features = torch.cat(
            (vert_x, torch.mm(self.Pi, shape_y.vert) - vert_x), dim=1
        )  # [n, 2d]
        num_verts = features.shape[0]

        # One copy of the features per time step, with the time value appended
        # as the last input channel.
        points_in = torch.cat(
            (
                features.unsqueeze(0).expand(num_steps, -1, -1),
                timesteps.to(features.dtype).expand(-1, num_verts, 1),
            ),
            dim=2,
        )  # [T, n, 2d + 1]

        edge_index = shape_x.get_edge_index()

        # Evaluate the network one time step at a time. The former batched call
        # `self.rn_ec(points_in, edge_index)` relied on batchwise processing in
        # torch-geometric that is now deprecated.
        displacement = torch.stack(
            [self.rn_ec(step_points, edge_index) for step_points in points_in], dim=0
        )  # [T, n, d]

        point_pred_arr = vert_x.unsqueeze(0) + displacement * timesteps  # [T, n, d]
        return point_pred_arr.permute(1, 2, 0)  # [n, d, T]