def interpolate()

in utils/interpolation_base.py [0:0]


    def interpolate(self, super_idx=-1):
        """Optimise the parameters of ``self.interp_module`` with Adam.

        Each of the ``self.param.num_it`` iterations does four things:
        evaluates the module (``self.interp_module()`` returns ``(E, Elist)``),
        back-propagates ``E``, calls ``self.interp_module.mul_with_inv_hessian()``,
        and takes one Adam step with learning rate ``self.param.lr``. The
        module is put into train mode before the loop and eval mode after it.

        Args:
            super_idx: Index of an enclosing outer ("super") iteration. It is
                used only to prefix the log lines; negative values (the
                default) omit the prefix.

        Returns:
            The detached ``E`` from the last iteration. If ``num_it`` is 0 and
            the loop never runs, a 0-dim zero tensor is returned instead.
        """
        optimizer = torch.optim.Adam(
            self.interp_module.parameters(), lr=self.param.lr
        )

        self.interp_module.train()

        # Fallback so `.detach()` below is valid even when no iteration runs.
        E = torch.zeros(())

        for it in range(self.param.num_it):
            optimizer.zero_grad()
            E, Elist = self.interp_module()
            E.backward()

            # Project hook that must run after backward() and before step().
            # Presumably it rescales the gradients in place; confirm against
            # the module implementation.
            self.interp_module.mul_with_inv_hessian()

            optimizer.step()

            if self.param.log:
                prefix = f"Super {super_idx:02d}, " if super_idx >= 0 else ""
                print(f"{prefix}It {it:03d}, E: {Elist[0]:.5f}")

        # NOTE(review): nn.Module.eval() returns the module itself, so
        # `self.energy` references `self.interp_module`, not an energy value.
        # Kept unchanged for backward compatibility; confirm the intent.
        self.energy = self.interp_module.eval()

        return E.detach()