in pytorch_managed_spot_training_checkpointing/source_dir/cifar10.py [0:0]
def _save_checkpoint(model, optimizer, epoch, loss, args):
    """Write the current training state to ``<args.checkpoint_path>/checkpoint.pth``.

    The saved object is a dict with the keys ``'epoch'`` (stored as
    ``epoch + 1``), ``'model_state_dict'``, ``'optimizer_state_dict'`` and
    ``'loss'``.

    The file is first written under a temporary name in the same directory
    and then renamed over the final name, so a process that is killed while
    saving leaves the previous checkpoint intact instead of a truncated file.

    Args:
        model: Object exposing ``state_dict()``; its state is saved.
        optimizer: Object exposing ``state_dict()``; its state is saved.
        epoch: Index of the epoch that just finished. ``epoch + 1`` is what
            gets printed and stored (presumably ``epoch`` is zero-based --
            confirm against the training loop).
        loss: Loss value to print and store alongside the states.
        args: Object with a ``checkpoint_path`` attribute naming the
            directory that receives ``checkpoint.pth``.

    Raises:
        OSError: If the directory cannot be created or the file cannot be
            written or renamed.
    """
    # Function-scope import keeps this block self-contained regardless of
    # what the rest of the module imports.
    import os

    print("epoch: {} - loss: {}".format(epoch+1, loss))
    # Plain concatenation is kept on purpose: the resulting string is printed
    # below and must stay identical to what any loader of this file builds.
    checkpointing_path = args.checkpoint_path + '/checkpoint.pth'
    print("Saving the Checkpoint: {}".format(checkpointing_path))

    # The target directory may not exist yet on a fresh run. An empty
    # checkpoint_path is skipped because os.makedirs('') raises.
    if args.checkpoint_path:
        os.makedirs(args.checkpoint_path, exist_ok=True)

    state = {
        'epoch': epoch+1,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }

    # Same directory as the final file, so os.replace stays on one filesystem.
    tmp_path = checkpointing_path + '.tmp'
    try:
        torch.save(state, tmp_path)
        os.replace(tmp_path, checkpointing_path)
    except BaseException:
        # Best-effort removal of the partial temp file; the original error is
        # what the caller needs to see, so cleanup failures are ignored.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise