in classification/model.py [0:0]
def load_network(opt, model, which_epoch, save_dir=''):
    """Load saved classifier weights into ``model`` and return their iteration.

    Args:
        opt: Options object. ``opt.network_arch`` selects the checkpoint
            family: ``'mlp'`` uses ``*_classifier`` names, any other value
            uses ``*_linear_classifier``. ``opt.save_dir`` is the fallback
            directory when ``save_dir`` is empty and ``which_epoch > 0``.
        model: Object exposing ``load_state_dict``; receives the result of
            ``torch.load`` on the checkpoint file.
        which_epoch: When greater than 0, the checkpoint named
            ``<which_epoch>_<family>.pth`` is loaded. Otherwise the file name
            is obtained from ``get_model_list(save_dir, <family>)``.
        save_dir: Directory holding the checkpoint files.

    Returns:
        ``which_epoch`` when it is greater than 0; otherwise the integer
        parsed from the part of the checkpoint file name before the first
        underscore.

    Raises:
        SystemExit: With status 1 when no checkpoint name is found, the
            checkpoint file does not exist, or loading it fails.
    """
    # Local import: `exit()` is a helper injected by the `site` module for
    # interactive use and may be missing (e.g. `python -S`); `sys.exit` is
    # always available.
    import sys

    if opt.network_arch == 'mlp':
        suffix = 'classifier'
    else:
        suffix = 'linear_classifier'

    if which_epoch > 0:
        save_filename = '%s_%s.pth' % (which_epoch, suffix)
        if not save_dir:
            save_dir = opt.save_dir
    else:
        # NOTE(review): the opt.save_dir fallback is not applied on this path,
        # so an empty save_dir is passed to get_model_list as-is -- confirm
        # that this is intended.
        # Presumably get_model_list returns a bare file name such as
        # '<iteration>_classifier.pth' (it is joined with save_dir and its
        # leading field is parsed as an int below) -- verify against the helper.
        save_filename = get_model_list(save_dir, suffix)
        if not save_filename:
            print("file does not exist")
            sys.exit(1)

    save_path = os.path.join(save_dir, save_filename)
    if not os.path.isfile(save_path):
        print('%s not exists yet!' % save_path)
        sys.exit(1)

    try:
        model.load_state_dict(torch.load(save_path))
    except Exception as err:
        # Catch Exception rather than everything so Ctrl-C and SystemExit
        # propagate, and report the cause instead of discarding it.
        print("cannot load pretrained model")
        print('%s: %s' % (type(err).__name__, err))
        sys.exit(1)

    if which_epoch > 0:
        iterations = which_epoch
    else:
        iterations = int(save_filename.split('_')[0])
    print('Resume from iteration %d' % iterations)
    return iterations