in model/interpolation_net.py [0:0]
def compute_loss(self, shape_x, shape_y, point_pred_arr, n_normalize=201.0):
    """Accumulate the total interpolation loss for one (shape_x, shape_y) pair.

    The total is the sum of:

    * an anchoring term between ``shape_x.vert`` and the first predicted frame,
      evaluated in both argument orders of ``self.interp_energy.forward_single``;
    * an alignment term weighted by ``n_normalize / shape_x.vert.shape[0]`` and
      ``self.param.lambd``: the squared norm of ``Pi @ shape_y.vert - last_frame``
      plus the squared norm of ``shape_y.vert - Pi_inv @ last_frame``;
    * a geodesic term ``lambd_geo * mean((Pi @ D_y @ Pi^T - D_x) ** 2)``, or
      ``my_tensor(0)`` when ``shape_x.D`` is None or ``self.param.lambd_geo == 0``;
    * for each of ``self.param.num_timesteps`` steps, the energy between frames
      ``i`` and ``i + 1`` of ``point_pred_arr``, again in both argument orders.

    Args:
        shape_x: source shape; ``.vert`` and ``.D`` are read, and the object is
            passed through to ``forward_single``.
        shape_y: target shape; ``.vert`` and ``.D`` are read.
        point_pred_arr: predicted vertex positions with frames stacked along the
            third axis (``point_pred_arr[:, :, k]``). Presumably it holds
            ``num_timesteps + 1`` frames — TODO confirm against the caller.
        n_normalize: numerator of the per-vertex scaling of the alignment term.

    Returns:
        ``(E, [E - E_align - E_geo, E_align, E_geo])``, where ``E`` is the total
        loss and the first list entry collects the anchoring and per-step terms.
    """
    energy_fn = self.interp_energy.forward_single
    first_frame = point_pred_arr[:, :, 0]
    last_frame = point_pred_arr[:, :, -1]

    # Tie the first predicted frame to the source vertices, both directions.
    anchor_term = energy_fn(shape_x.vert, first_frame, shape_x) + energy_fn(
        first_frame, shape_x.vert, shape_x
    )

    # Alignment of the final frame with the target; the weight is normalized
    # by the vertex count of shape_x.
    align_scale = n_normalize / shape_x.vert.shape[0] * self.param.lambd
    forward_residual = torch.mm(self.Pi, shape_y.vert) - last_frame
    backward_residual = shape_y.vert - torch.mm(self.Pi_inv, last_frame)
    align_term = align_scale * (
        forward_residual.norm() ** 2 + backward_residual.norm() ** 2
    )

    if shape_x.D is None or self.param.lambd_geo == 0:
        # No distance matrix on shape_x, or the geodesic term is switched off.
        geo_term = my_tensor(0)
    else:
        transported_D = torch.mm(
            torch.mm(self.Pi, shape_y.D), self.Pi.transpose(0, 1)
        )
        geo_term = self.param.lambd_geo * ((transported_D - shape_x.D) ** 2).mean()

    total = anchor_term + align_term + geo_term
    # Energies between consecutive frames, evaluated forward then backward.
    for step in range(self.param.num_timesteps):
        current = point_pred_arr[:, :, step]
        following = point_pred_arr[:, :, step + 1]
        total = (
            total
            + energy_fn(current, following, shape_x)
            + energy_fn(following, current, shape_x)
        )

    return total, [total - align_term - geo_term, align_term, geo_term]