def load_model()

in PyTorchClassification/data_loader_cv.py [0:0]


def load_model(filename, useGPU=True):
    """
    Loads a model from a checkpoint.

    Reads a checkpoint dict written with ``torch.save`` and copies its
    contents onto the ``data`` object, which is then returned.

    Args:
        filename: Path to the checkpoint file.
        useGPU: If True, every tensor is mapped onto the current CUDA
            device (``torch.cuda.current_device()``), which raises if CUDA
            is unavailable. If False, tensors are loaded onto the CPU.

    Returns:
        The ``data`` object with these attributes assigned:
        ``best_prec1``, ``best_prec3``, ``best_prec5``, ``start_epoch``
        (each defaults to 0 when missing from the checkpoint),
        ``classnames``, ``model_dict``, ``optimizer_dict`` (None when the
        checkpoint has no ``'optimizer'`` entry) and ``model_type``.
        Returns None, after printing a message, if ``filename`` is not an
        existing file.

    Raises:
        KeyError: if the checkpoint lacks ``'state_dict'``,
            ``'classnames'`` or ``'model_type'``.
    """
    if not os.path.isfile(filename):
        print("=> no checkpoint found at '{}'".format(filename))
        return None

    print("=> loading checkpoint '{}'".format(filename))

    if useGPU:
        cuda_device = torch.cuda.current_device()
        checkpoint = torch.load(filename, map_location=lambda storage, loc: storage.cuda(cuda_device))
    else:
        checkpoint = torch.load(filename, map_location=lambda storage, loc: storage)

    # Optional bookkeeping fields; older checkpoints may not contain them.
    start_epoch = checkpoint.get('epoch', 0)
    best_prec1 = checkpoint.get('best_prec1', 0)
    best_prec3 = checkpoint.get('best_prec3', 0)
    best_prec5 = checkpoint.get('best_prec5', 0)

    # Required fields: a missing key raises KeyError.
    state_dict = checkpoint['state_dict']
    classnames = checkpoint['classnames']
    model_type = checkpoint['model_type']

    print('Loaded %d classes' % len(classnames))

    # Strip wrapper prefixes from parameter names so they match an
    # unwrapped model. 'module.' is added by DataParallel. The origin of
    # '1.' and 'model.' is not visible here; presumably they come from a
    # wrapping container/attribute (confirm). Each prefix is tested once,
    # in this fixed order. If two keys collapse to the same name, the later
    # one overwrites the earlier.
    new_state_dict = OrderedDict()
    for key, value in state_dict.items():
        for prefix in ('module.', '1.', 'model.'):
            if key.startswith(prefix):
                key = key[len(prefix):]
        new_state_dict[key] = value

    model_dict = new_state_dict
    optimizer_dict = checkpoint.get('optimizer')

    print("=> loaded checkpoint '{}' (epoch {})"
            .format(filename, start_epoch))

    # NOTE(review): `data` is not defined in this function. It is presumably
    # a module-level object defined elsewhere in this file, so each call
    # mutates and returns that same shared object. Confirm.
    data.best_prec1 = best_prec1
    data.best_prec3 = best_prec3
    data.best_prec5 = best_prec5
    data.start_epoch = start_epoch
    data.classnames = classnames
    data.model_dict = model_dict
    data.optimizer_dict = optimizer_dict
    data.model_type = model_type

    return data