in model/interpolation_net.py [0:0]
def load_chkpt(self, ckpt_path):
    """Restore state from the checkpoint stored at ``ckpt_path``.

    Args:
        ckpt_path: Location of the checkpoint, handed unchanged to
            ``torch.load``.

    Required checkpoint entries:
        ``"i_epoch"``: copied to ``self.i_epoch``.
        ``"interp_module"``: state dict loaded into ``self.interp_module``.

    Optional checkpoint entries:
        ``"par"``: passed to ``self.interp_module.param.from_dict``, after
            which the resulting ``param`` object is printed via
            ``print_self()``.
        ``"optimizer_state_dict"``: loaded into ``self.optimizer``.

    A missing required entry raises whatever the loaded object's
    subscript raises (``KeyError`` if it is a plain dict, which is
    presumably the case). On return, ``self.interp_module`` has been put in
    training mode via ``.train()``.
    """
    # ``device`` is a module-level name defined outside this view.
    # NOTE(review): torch.load unpickles the file, so only load trusted
    # checkpoints. Newer torch releases default to weights_only=True, which
    # may reject some entries depending on their contents. Confirm against
    # the torch version in use.
    checkpoint = torch.load(ckpt_path, map_location=device)
    module = self.interp_module

    # Mandatory entries: epoch counter, then the network weights.
    self.i_epoch = checkpoint["i_epoch"]
    module.load_state_dict(checkpoint["interp_module"])

    # Optional ``param`` settings saved alongside the weights.
    if "par" in checkpoint:
        module.param.from_dict(checkpoint["par"])
        module.param.print_self()

    # Optional optimizer state (presumably so training can be resumed).
    if "optimizer_state_dict" in checkpoint:
        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

    module.train()