def main()

in examples/mnist/pytorch_example.py [0:0]


def main():
    """Train and evaluate the MNIST example model from a petastorm dataset.

    Parses the command-line options, seeds torch, builds the model and an SGD
    optimizer, and then alternates training and evaluation:

    * with ``--all-epochs``: the outer loop runs once, the training reader
      repeats the data ``--epochs`` times, and accuracy/loss is tested once
      at the end;
    * otherwise: the outer loop runs ``--epochs`` times, each iteration
      training on a single reader epoch and then testing.
    """
    # Training settings
    parser = argparse.ArgumentParser(description='Petastorm MNIST Example')
    default_dataset_url = 'file://{}'.format(DEFAULT_MNIST_DATA_PATH)
    # argparse %-formats help strings itself, so let it substitute the default
    # via %(default)s instead of pre-interpolating the path (a '%' in the path
    # would otherwise break `--help`).
    parser.add_argument('--dataset-url', type=str,
                        default=default_dataset_url, metavar='S',
                        help='hdfs:// or file:/// URL to the MNIST petastorm dataset '
                             '(default: %(default)s)')
    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs', type=int, default=10, metavar='N',
                        help='number of epochs to train (default: 10)')
    parser.add_argument('--all-epochs', action='store_true', default=False,
                        help='train all epochs before testing accuracy/loss')
    parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                        help='learning rate (default: 0.01)')
    parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
                        help='SGD momentum (default: 0.5)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')
    args = parser.parse_args()

    # CUDA is used only when it is both available and not disabled by the user.
    use_cuda = not args.no_cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)

    device = torch.device('cuda' if use_cuda else 'cpu')

    model = Net().to(device)
    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)

    # Configure loop and Reader epoch for illustrative purposes.
    # Typical training usage would use the `all_epochs` approach.
    #
    if args.all_epochs:
        # Run training across all the epochs before testing for accuracy
        loop_epochs = 1
        reader_epochs = args.epochs
    else:
        # Test training accuracy after each epoch
        loop_epochs = args.epochs
        reader_epochs = 1

    transform = TransformSpec(_transform_row, removed_fields=['idx'])

    # A fresh petastorm Reader (wrapped in a DataLoader) is created for every
    # loop iteration; worker and shuffling options are left at the
    # make_reader defaults.
    for epoch in range(1, loop_epochs + 1):
        with DataLoader(make_reader('{}/train'.format(args.dataset_url), num_epochs=reader_epochs,
                                    transform_spec=transform),
                        batch_size=args.batch_size) as train_loader:
            train(model, device, train_loader, args.log_interval, optimizer, epoch)
        # Evaluation only needs a single pass over the test split, regardless
        # of how many epochs the training reader repeats; re-reading it
        # `reader_epochs` times under --all-epochs would just repeat the work.
        with DataLoader(make_reader('{}/test'.format(args.dataset_url), num_epochs=1,
                                    transform_spec=transform),
                        batch_size=args.test_batch_size) as test_loader:
            test(model, device, test_loader)