in utils/interpolation_base.py [0:0]
def interpolate(self, super_idx=-1):
    """Minimise the energy of ``self.interp_module`` with L-BFGS.

    Runs ``self.param.num_it`` outer L-BFGS steps (strong-Wolfe line
    search, learning rate ``self.param.lr``) over the parameters of
    ``self.interp_module``. The module is called without arguments and
    must return a pair ``(E, Elist)``; ``E`` is the value being
    minimised and ``Elist[0]`` is the value that gets logged.

    Args:
        super_idx: Index included in the log line as a ``Super NN``
            prefix when it is ``>= 0``; with the default ``-1`` the
            prefix is omitted. Only used for logging.

    Side effects:
        Switches ``self.interp_module`` to train mode before optimising
        and calls ``eval()`` on it afterwards, storing the return value
        of that call in ``self.energy``. When ``self.param.log`` is
        truthy, one line is printed per closure evaluation.
    """
    step_size = self.param.lr
    solver = torch.optim.LBFGS(
        self.interp_module.parameters(),
        lr=step_size,
        line_search_fn="strong_wolfe",
    )
    self.interp_module.train()

    def evaluate_energy():
        # ``step_idx`` is read from the enclosing loop below; this closure
        # is only ever invoked synchronously from ``solver.step``.
        if torch.is_grad_enabled():
            solver.zero_grad()

        energy, energy_terms = self.interp_module()

        if energy.requires_grad:
            energy.backward()

        if not self.param.log:
            return energy

        if super_idx >= 0:
            line = "Super {:02d}, It {:03d}, E: {:.5f}".format(
                super_idx, step_idx, energy_terms[0]
            )
        else:
            line = "It {:03d}, E: {:.5f}".format(step_idx, energy_terms[0])
        print(line)
        return energy

    for step_idx in range(self.param.num_it):
        # The solver may call the closure several times per step.
        solver.step(evaluate_energy)

    # NOTE(review): for a stock torch.nn.Module, eval() returns the module
    # itself, not an energy value -- confirm whether interp_module
    # overrides eval() or whether a different call was intended here.
    self.energy = self.interp_module.eval()