in PyTorchClassification/data_loader.py [0:0]
def _strip_wrapper_prefixes(key):
    """Return *key* with leading wrapper prefixes removed.

    The prefixes 'module.', '1.' and 'model.' are each stripped at most once
    and only in that order, e.g. 'module.1.model.fc.weight' -> 'fc.weight',
    whereas 'model.module.fc.weight' -> 'module.fc.weight'.
    """
    # 'module.' is the prefix added by DataParallel; '1.' and 'model.' are
    # presumably left by other container wrappers used at save time --
    # TODO confirm against the training code.
    for prefix in ('module.', '1.', 'model.'):
        if key.startswith(prefix):
            key = key[len(prefix):]
    return key


def load_model(filename, useGPU=True):
    """
    Loads a model from a checkpoint.

    Args:
        filename: Path to a checkpoint file readable by ``torch.load``.
        useGPU: If true, every storage is moved to the current CUDA device
            while loading; otherwise storages are left where ``torch.load``
            deserializes them.

    Returns:
        ``None`` (after printing a message) when ``filename`` is not an
        existing file. Otherwise the module-level ``data`` object, with these
        attributes set on it:

        * ``start_epoch``, ``best_prec1``, ``best_prec3``, ``best_prec5``:
          taken from the checkpoint, 0 when the key is absent.
        * ``classnames``, ``model_type``: taken from the checkpoint.
        * ``model_dict``: an ``OrderedDict`` built from the checkpoint's
          ``state_dict`` with wrapper prefixes stripped from every key.
        * ``optimizer_dict``: the checkpoint's ``optimizer`` entry, or
          ``None`` when absent.

    Raises:
        KeyError: If the checkpoint lacks 'state_dict', 'classnames' or
            'model_type'.
    """
    if not os.path.isfile(filename):
        print("=> no checkpoint found at '{}'".format(filename))
        return None

    print("=> loading checkpoint '{}'".format(filename))
    # NOTE: torch.load unpickles the file; only load checkpoints that come
    # from a trusted source.
    if useGPU:
        cuda_device = torch.cuda.current_device()
        checkpoint = torch.load(filename, map_location=lambda storage, loc: storage.cuda(cuda_device))
    else:
        checkpoint = torch.load(filename, map_location=lambda storage, loc: storage)

    # Optional bookkeeping entries default to 0 when they were not saved.
    start_epoch = checkpoint.get('epoch', 0)
    best_prec1 = checkpoint.get('best_prec1', 0)
    best_prec3 = checkpoint.get('best_prec3', 0)
    best_prec5 = checkpoint.get('best_prec5', 0)

    # Mandatory entries: a missing key raises KeyError.
    state_dict = checkpoint['state_dict']
    classnames = checkpoint['classnames']
    model_type = checkpoint['model_type']
    print('Loaded %d classes' % len(classnames))

    # Re-key the weights so they no longer carry the wrapper prefixes.
    # If two keys normalize to the same name, the later one wins.
    model_dict = OrderedDict()
    for key, value in state_dict.items():
        model_dict[_strip_wrapper_prefixes(key)] = value

    optimizer_dict = checkpoint.get('optimizer')
    print("=> loaded checkpoint '{}' (epoch {})"
          .format(filename, start_epoch))

    # NOTE(review): `data` is not created in this function; it is resolved
    # from module scope and mutated in place, so every call overwrites the
    # attributes written by the previous one. Kept as-is because callers may
    # rely on it -- confirm what `data` is bound to and consider returning a
    # fresh container instead.
    data.best_prec1 = best_prec1
    data.best_prec3 = best_prec3
    data.best_prec5 = best_prec5
    data.start_epoch = start_epoch
    data.classnames = classnames
    data.model_dict = model_dict
    data.optimizer_dict = optimizer_dict
    data.model_type = model_type
    return data