in src/util.py [0:0]
def load(model_class, dir_path, opt, reset_params=False):
    """Restore a model and its training state from a checkpoint directory.

    The directory must contain whatever ``model_class.from_pretrained`` reads,
    plus an ``optimizer.pth.tar`` file loadable by ``torch.load``. That file
    must hold the keys ``"opt"``, ``"step"``, ``"optimizer"`` and
    ``"scheduler"``, plus either ``"best_eval_metric"`` or ``"best_dev_em"``.

    Args:
        model_class: Class exposing a ``from_pretrained(path)`` constructor.
        dir_path: Path to the checkpoint directory. It is passed through
            ``os.path.realpath`` first.
        opt: Current run options. ``opt.device`` selects where the model and
            the checkpoint tensors are placed.
        reset_params: If False, rebuild the optimizer and scheduler from the
            options stored in the checkpoint and restore their saved state.
            If True, build fresh ones from the current ``opt`` instead.

    Returns:
        Tuple ``(model, optimizer, scheduler, opt_checkpoint, step,
        best_eval_metric)``. ``opt_checkpoint`` is always the options stored
        in the checkpoint, even when ``reset_params`` is True.

    Raises:
        KeyError: If a required key is missing from the checkpoint.
    """
    epoch_path = os.path.realpath(dir_path)
    optimizer_path = os.path.join(epoch_path, "optimizer.pth.tar")

    # Pass format arguments to the logger so interpolation is deferred until
    # the record is actually emitted (output text is unchanged).
    logger.info("Loading %s", epoch_path)
    model = model_class.from_pretrained(epoch_path).to(opt.device)

    logger.info("loading checkpoint %s", optimizer_path)
    # NOTE(review): torch.load unpickles arbitrary objects, so only load
    # checkpoints from trusted sources. On PyTorch >= 2.6 the default
    # weights_only=True may reject a pickled "opt" object. Confirm the torch
    # version in use and pass weights_only=False if needed.
    checkpoint = torch.load(optimizer_path, map_location=opt.device)
    opt_checkpoint = checkpoint["opt"]
    step = checkpoint["step"]
    # Fall back to "best_dev_em", presumably the key written by older
    # checkpoints. Verify against the corresponding save routine.
    if "best_eval_metric" in checkpoint:
        best_eval_metric = checkpoint["best_eval_metric"]
    else:
        best_eval_metric = checkpoint["best_dev_em"]

    if reset_params:
        # Fresh optimizer/scheduler from the *current* run options; the saved
        # optimizer/scheduler state is ignored.
        optimizer, scheduler = set_optim(opt, model)
    else:
        # Rebuild from the *saved* options, then restore the saved state.
        optimizer, scheduler = set_optim(opt_checkpoint, model)
        scheduler.load_state_dict(checkpoint["scheduler"])
        optimizer.load_state_dict(checkpoint["optimizer"])

    return model, optimizer, scheduler, opt_checkpoint, step, best_eval_metric