in utils/checkpointing.py [0:0]
def restart_from_checkpoint(ckp_path, args, run_variables=None, **kwargs):
    """
    Re-start from checkpoint
    """
    print("Found checkpoint at {}".format(ckp_path))

    # In the distributed case, wait for all ranks, then remap tensors that
    # were saved from GPU 0 onto this process's own device. Otherwise load
    # everything onto CPU.
    if args.distributed:
        dist.barrier()
        # configure map_location properly
        remap = {"cuda:%d" % 0: "cuda:%d" % args.rank}
        checkpoint = torch.load(ckp_path, map_location=remap)
    else:
        checkpoint = torch.load(ckp_path, map_location="cpu")

    def _patch_legacy_aug_params(state):
        # for compatibility with previous versions of augerino where
        # width and min_values were 1d: promote them to 2d in place.
        for param_name, notice in (
            ("module.aug.width", "extending the size of width"),
            ("module.aug.min_values", "extending the size of min_values"),
        ):
            if param_name in state and len(state[param_name].size()) == 1:
                print(notice)
                state[param_name] = state[param_name].unsqueeze(0)

    # Each kwarg names an entry to look up in the checkpoint file; its value
    # is the object whose state gets restored, e.g. {'state_dict': model}.
    for name, target in kwargs.items():
        if name not in checkpoint or target is None:
            print("=> failed to load {} from checkpoint '{}'".format(name, ckp_path))
            continue
        try:
            if name == "state_dict":
                _patch_legacy_aug_params(checkpoint[name])
            msg = target.load_state_dict(checkpoint[name], strict=True)
            print(msg)
        except TypeError:
            # Some stateful objects (e.g. optimizers) do not accept `strict`.
            msg = target.load_state_dict(checkpoint[name])
        print("=> loaded {} from checkpoint '{}'".format(name, ckp_path))

    # Restore scalar run variables (epoch counters, etc.) that the caller
    # asked for, leaving untouched any name absent from the checkpoint.
    if run_variables is not None:
        for var_name in list(run_variables):
            if var_name in checkpoint:
                run_variables[var_name] = checkpoint[var_name]