in utils/interpolation_base.py [0:0]
def mul_with_inv_hessian(self):
    """Precondition ``vert_sequence.grad`` with the inverse of a space-time Hessian.

    Builds the sparse block matrix

        H = kron(D, H_x) + kron(B, I)

    where ``D`` is the ``num_timesteps x num_timesteps`` second-difference
    matrix (rows ``[-1, 2, -1]``) with its first and last rows zeroed,
    ``H_x`` is ``self.energy.get_hessian(self.shape_x)``, ``B`` selects the
    first and last timestep, and ``I`` is the ``n_vert x n_vert`` identity.
    The first and last block rows of ``H`` therefore reduce to the identity,
    so the gradient at those timesteps passes through unchanged.

    ``H x = g`` is solved with ``spsolve`` for all three coordinate columns
    of the gradient in a single call. The result replaces
    ``self.vert_sequence.grad`` with the same shape, dtype and device as
    ``self.vert_sequence``.

    Requires ``self.vert_sequence.grad`` to be populated (not ``None``) and
    ``H_x`` to be ``n_vert x n_vert`` so that both Kronecker terms have the
    same size.
    """
    num_t = self.param.num_timesteps
    n_vert = self.shape_x.get_vert_shape()[0]
    hess_1d = self.energy.get_hessian(self.shape_x)

    # Temporal second-difference stencil [-1, 2, -1] on the three central diagonals.
    central_diff_diags = -np.ones([3, num_t])
    central_diff_diags[1, :] = 2
    central_diff = lil_matrix(
        spdiags(central_diff_diags, np.array([-1, 0, 1]), num_t, num_t)
    )
    # The first and last timestep get identity rows instead of the stencil.
    central_diff[[0, num_t - 1], :] = 0
    boundary_cond = lil_matrix((num_t, num_t))
    boundary_cond[0, 0] = 1
    boundary_cond[num_t - 1, num_t - 1] = 1

    # Time-major block layout: block (t, s) is central_diff[t, s] * hess_1d.
    hess = csr_matrix(
        kron(central_diff, hess_1d) + kron(boundary_cond, eye(n_vert))
    )

    vert_sequence = self.vert_sequence
    # Rows of the (-1, 3) gradient must be time-major to match the kron
    # layout (assumes vert_sequence is (num_timesteps, n_vert, 3) - TODO
    # confirm). reshape tolerates a non-contiguous gradient, and a single
    # .cpu() is enough to get the host copy that scipy needs.
    grad_np = vert_sequence.grad.detach().reshape(-1, 3).cpu().numpy()
    grad_hess = spsolve(hess, grad_np)

    # PyTorch only accepts a .grad whose dtype and device match the tensor,
    # so take both from vert_sequence rather than hard-coding float32.
    # torch.tensor already copies, so no extra clone is required.
    vert_sequence.grad = torch.tensor(
        grad_hess, dtype=vert_sequence.dtype, device=vert_sequence.device
    ).reshape(vert_sequence.shape)