in utils/interpolation_base.py [0:0]
def interpolate(self, super_idx=-1):
    """Optimize the parameters of ``self.interp_module`` with Adam.

    Runs ``self.param.num_it`` iterations. Each iteration:

    1. Evaluates the module, which must return a pair ``(E, Elist)``.
    2. Back-propagates ``E``.
    3. Calls ``self.interp_module.mul_with_inv_hessian()``.
    4. Takes an Adam step with learning rate ``self.param.lr``.

    The module is put in train mode for the optimization and switched to
    eval mode afterwards. When ``self.param.log`` is truthy, one line per
    iteration is printed showing ``Elist[0]``.

    Args:
        super_idx: Optional outer-loop index. When ``>= 0`` it is prefixed to
            every log line; the default ``-1`` omits it.

    Returns:
        The detached energy ``E`` from the last iteration. It is evaluated
        before that iteration's optimizer step. If ``self.param.num_it`` is
        0, a zero scalar tensor is returned instead of raising.
    """
    optimizer = torch.optim.Adam(self.interp_module.parameters(), lr=self.param.lr)
    self.interp_module.train()

    # Sentinel returned when no iterations run. It must be a tensor so that
    # the final .detach() works (a plain int 0 would raise AttributeError).
    E = torch.tensor(0.0)
    for it in range(self.param.num_it):
        optimizer.zero_grad()
        E, Elist = self.interp_module()
        E.backward()
        # Called after backward() and before step(), presumably to
        # precondition the gradients in place -- confirm in the module.
        self.interp_module.mul_with_inv_hessian()
        optimizer.step()

        if self.param.log:
            # float() accepts Python numbers as well as 0-d / one-element
            # tensors, which a bare format spec on a tensor does not.
            e0 = float(Elist[0])
            if super_idx >= 0:
                print(f"Super {super_idx:02d}, It {it:03d}, E: {e0:.5f}")
            else:
                print(f"It {it:03d}, E: {e0:.5f}")

    # NOTE(review): torch.nn.Module.eval() returns the module itself, so
    # unless interp_module overrides eval(), self.energy ends up aliasing the
    # module rather than holding an energy value. Kept unchanged for
    # compatibility with any readers of self.energy -- confirm intent.
    self.energy = self.interp_module.eval()
    return E.detach()