def restart_from_checkpoint()

in utils/checkpointing.py [0:0]


import torch
import torch.distributed as dist


def restart_from_checkpoint(ckp_path, args, run_variables=None, **kwargs):
    """
    Re-start a run from a checkpoint.

    ckp_path: path to the checkpoint file.
    args: run arguments; args.distributed and args.rank decide how tensors
        are mapped onto devices when loading.
    run_variables: dict updated in place with matching entries from the
        checkpoint (e.g. the current epoch).
    **kwargs: mapping from a key in the checkpoint file to the object whose
        load_state_dict should receive it, e.g. {'state_dict': model}.
    """
    print("Found checkpoint at {}".format(ckp_path))

    if args.distributed:
        dist.barrier()
        # map tensors saved on GPU 0 onto the GPU used by this process
        map_location = {"cuda:%d" % 0: "cuda:%d" % args.rank}
        # open checkpoint file
        checkpoint = torch.load(ckp_path, map_location=map_location)
    else:
        checkpoint = torch.load(ckp_path, map_location="cpu")

    # key is what to look for in the checkpoint file
    # value is the object to load
    # example: {'state_dict': model}
    for key, value in kwargs.items():
        if key in checkpoint and value is not None:
            try:
                # for compatibility with previous versions of augerino where
                # width and min_values were 1d
                if key == "state_dict":
                    if "module.aug.width" in checkpoint[key]:
                        if len(checkpoint[key]["module.aug.width"].size()) == 1:
                            print("extending the size of width")
                            checkpoint[key]["module.aug.width"] = checkpoint[key][
                                "module.aug.width"
                            ].unsqueeze(0)
                    if "module.aug.min_values" in checkpoint[key]:
                        if len(checkpoint[key]["module.aug.min_values"].size()) == 1:
                            print("extending the size of min_values")
                            checkpoint[key]["module.aug.min_values"] = checkpoint[key][
                                "module.aug.min_values"
                            ].unsqueeze(0)
                # nn.Module.load_state_dict accepts `strict`; msg reports any
                # missing or unexpected keys
                msg = value.load_state_dict(checkpoint[key], strict=True)
                print(msg)
            except TypeError:
                # objects such as optimizers and lr schedulers do not accept
                # the `strict` argument
                msg = value.load_state_dict(checkpoint[key])
            print("=> loaded {} from checkpoint '{}'".format(key, ckp_path))
        else:
            print("=> failed to load {} from checkpoint '{}'".format(key, ckp_path))

    # reload variables that are important for the run (e.g. the epoch)
    if run_variables is not None:
        for var_name in run_variables:
            if var_name in checkpoint:
                run_variables[var_name] = checkpoint[var_name]
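
Usage sketch (illustrative, not taken from the repository): assuming a training
script with a model, an optimizer, and an args namespace carrying `distributed`
and `rank`, and a checkpoint saved with keys "epoch", "state_dict" and
"optimizer", resuming could look like this:

# names, path and checkpoint keys below are assumptions for illustration
to_restore = {"epoch": 0}
restart_from_checkpoint(
    "path/to/checkpoint.pth",
    args,
    run_variables=to_restore,
    state_dict=model,        # routed to model.load_state_dict
    optimizer=optimizer,     # routed to optimizer.load_state_dict (TypeError fallback, no `strict`)
)
start_epoch = to_restore["epoch"]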