def _save_checkpoint()

in pytorch_managed_spot_training_checkpointing/source_dir/cifar10.py [0:0]


def _save_checkpoint(model, optimizer, epoch, loss, args):
    print("epoch: {} - loss: {}".format(epoch+1, loss))
    checkpointing_path = args.checkpoint_path + '/checkpoint.pth'
    print("Saving the Checkpoint: {}".format(checkpointing_path))
    torch.save({
        'epoch': epoch+1,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
        }, checkpointing_path)