def compute_loss()

in model/interpolation_net.py [0:0]


    def compute_loss(self, shape_x, shape_y, point_pred_arr, n_normalize=201.0):
        """Compute the total interpolation loss and its individual components.

        The loss is the sum of three terms:

        * an interpolation energy: ``self.interp_energy.forward_single``
          evaluated in both argument orders for every consecutive pair of the
          chain ``shape_x.vert -> point_pred_arr[:, :, 0] -> ... ->
          point_pred_arr[:, :, num_timesteps]``;
        * an alignment energy comparing the last predicted frame
          (``point_pred_arr[:, :, -1]``) with ``shape_y.vert`` mapped through
          ``self.Pi``, plus ``shape_y.vert`` against the last frame mapped
          through ``self.Pi_inv``;
        * a distance-consistency energy, the mean squared difference between
          ``Pi @ shape_y.D @ Pi^T`` and ``shape_x.D``. It is zero when either
          ``D`` is missing or ``self.param.lambd_geo == 0``.

        Args:
            shape_x: Source shape. Its ``vert`` and (optionally) ``D`` are read,
                and it is passed as the third argument of every
                ``interp_energy.forward_single`` call.
            shape_y: Target shape. Its ``vert`` and (optionally) ``D`` are read.
            point_pred_arr: Predicted positions whose last axis indexes frames.
                Frames ``0 .. num_timesteps`` feed the interpolation energy and
                frame ``-1`` feeds the alignment energy. Presumably shaped
                ``(n_vertices, dim, n_frames)`` — confirm against the caller.
            n_normalize: Numerator of the alignment weight
                ``n_normalize / shape_x.vert.shape[0]``, i.e. the alignment term
                is rescaled by the vertex count of ``shape_x``.

        Returns:
            ``(E, [E_interp, E_align, E_geo])``: the total loss and its three
            components, with ``E == E_interp + E_align + E_geo``.
        """
        # Symmetric interpolation energy between the input shape and the first
        # predicted frame.
        first_frame = point_pred_arr[:, :, 0]
        E_interp = self.interp_energy.forward_single(
            shape_x.vert, first_frame, shape_x
        ) + self.interp_energy.forward_single(first_frame, shape_x.vert, shape_x)

        # Alignment of the final predicted frame with the target shape, in both
        # mapping directions. The weight is normalized by the vertex count.
        final_frame = point_pred_arr[:, :, -1]
        lambda_align = n_normalize / shape_x.vert.shape[0]
        E_align = (
            lambda_align
            * self.param.lambd
            * (
                (torch.mm(self.Pi, shape_y.vert) - final_frame).norm() ** 2
                + (shape_y.vert - torch.mm(self.Pi_inv, final_frame)).norm() ** 2
            )
        )

        # The distance term needs D on BOTH shapes; torch.mm would fail on a
        # missing shape_y.D, so treat it like a missing shape_x.D (term = 0).
        if shape_x.D is None or shape_y.D is None or self.param.lambd_geo == 0:
            E_geo = my_tensor(0)
        else:
            D_y_mapped = torch.mm(
                torch.mm(self.Pi, shape_y.D), self.Pi.transpose(0, 1)
            )
            E_geo = self.param.lambd_geo * ((D_y_mapped - shape_x.D) ** 2).mean()

        # Symmetric interpolation energy for each pair of consecutive frames.
        # NOTE(review): assumes point_pred_arr.shape[2] == num_timesteps + 1, so
        # that the chain ends at the frame used for alignment above — confirm.
        for i in range(self.param.num_timesteps):
            curr_frame = point_pred_arr[:, :, i]
            next_frame = point_pred_arr[:, :, i + 1]
            E_interp = (
                E_interp
                + self.interp_energy.forward_single(curr_frame, next_frame, shape_x)
                + self.interp_energy.forward_single(next_frame, curr_frame, shape_x)
            )

        # Accumulate the interpolation term directly rather than recovering it
        # as E - E_align - E_geo, which can lose precision to cancellation.
        E = E_interp + E_align + E_geo

        return E, [E_interp, E_align, E_geo]