in pytorch_managed_spot_training_checkpointing/source_dir/cifar10.py [0:0]
def _load_checkpoint(model, optimizer, args):
    """Restore training state from ``<args.checkpoint_path>/checkpoint.pth``.

    The checkpoint is read as a mapping with the keys ``'model_state_dict'``,
    ``'optimizer_state_dict'``, ``'epoch'`` and ``'loss'``. The two state
    dicts are loaded into ``model`` and ``optimizer`` in place; ``'loss'`` is
    only reported on stdout.

    Args:
        model: Object exposing ``load_state_dict``; updated in place.
        optimizer: Object exposing ``load_state_dict``; updated in place.
        args: Object with a ``checkpoint_path`` attribute naming the
            directory that contains ``checkpoint.pth``.

    Returns:
        A ``(model, optimizer, epoch_number)`` tuple, where ``epoch_number``
        is the value stored under ``'epoch'`` in the checkpoint.
    """
    print("--------------------------------------------")
    print("Checkpoint file found!")
    # Build the path once so the logged location and the loaded file agree.
    checkpoint_file = args.checkpoint_path + '/checkpoint.pth'
    print("Loading Checkpoint From: {}".format(checkpoint_file))

    # Deserialize onto the CPU so a checkpoint written on a CUDA device can
    # still be read on a host without that device. The load_state_dict calls
    # below copy the values into the existing parameters / optimizer state.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_file, map_location='cpu')

    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch_number = checkpoint['epoch']
    loss = checkpoint['loss']

    print("Checkpoint File Loaded - epoch_number: {} - loss: {}".format(epoch_number, loss))
    print('Resuming training from epoch: {}'.format(epoch_number+1))
    print("--------------------------------------------")
    return model, optimizer, epoch_number