# examples/mnist/pytorch_example.py
# NOTE: this excerpt shows main() from the petastorm PyTorch MNIST example.
# Net, train, test, and DEFAULT_MNIST_DATA_PATH are defined elsewhere in the
# file; a sketch of _transform_row appears at the end. The import block below
# is a plausible reconstruction and may differ from the real file.
import argparse

import torch
import torch.optim as optim

from petastorm import make_reader
from petastorm.pytorch import DataLoader
from petastorm.transform import TransformSpec


def main():
    # Training settings
    parser = argparse.ArgumentParser(description='Petastorm MNIST Example')
    default_dataset_url = 'file://{}'.format(DEFAULT_MNIST_DATA_PATH)
    parser.add_argument('--dataset-url', type=str,
                        default=default_dataset_url, metavar='S',
                        help='hdfs:// or file:/// URL to the MNIST petastorm dataset '
                             '(default: %s)' % default_dataset_url)
    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs', type=int, default=10, metavar='N',
                        help='number of epochs to train (default: 10)')
    parser.add_argument('--all-epochs', action='store_true', default=False,
                        help='train all epochs before testing accuracy/loss')
    parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                        help='learning rate (default: 0.01)')
    parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
                        help='SGD momentum (default: 0.5)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')
    args = parser.parse_args()
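
    # Example invocation (hypothetical paths; the dataset must already have
    # been generated, e.g. with examples/mnist/generate_petastorm_mnist.py):
    #   python pytorch_example.py --dataset-url file:///tmp/mnist --epochs 5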
    # Seed the RNG and pick the device before building the model and optimizer.
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    torch.manual_seed(args.seed)
    device = torch.device('cuda' if use_cuda else 'cpu')

    model = Net().to(device)
    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)
    # Configure loop and Reader epoch settings for illustrative purposes.
    # Typical training usage would use the `all_epochs` approach.
    #
    if args.all_epochs:
        # Run training across all the epochs before testing for accuracy
        loop_epochs = 1
        reader_epochs = args.epochs
    else:
        # Test training accuracy after each epoch
        loop_epochs = args.epochs
        reader_epochs = 1
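    # For example, with --epochs 3 the default mode runs three outer loops,
    # each over a one-epoch Reader (testing after every epoch), while
    # --all-epochs runs one outer loop over a three-epoch Reader and tests
    # only once at the end.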

    # Apply _transform_row to every row and drop the unused 'idx' field.
    transform = TransformSpec(_transform_row, removed_fields=['idx'])
    # Instantiate a petastorm Reader per split with the appropriate epoch
    # setting; pooling and shuffling are left at make_reader's defaults.
    for epoch in range(1, loop_epochs + 1):
        with DataLoader(make_reader('{}/train'.format(args.dataset_url), num_epochs=reader_epochs,
                                    transform_spec=transform),
                        batch_size=args.batch_size) as train_loader:
            train(model, device, train_loader, args.log_interval, optimizer, epoch)

        with DataLoader(make_reader('{}/test'.format(args.dataset_url), num_epochs=reader_epochs,
                                    transform_spec=transform),
                        batch_size=args.test_batch_size) as test_loader:
            test(model, device, test_loader)
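

# A minimal sketch of the row-level transform referenced by the TransformSpec
# in main(). This is a hypothetical reconstruction, not the file's actual
# code: it assumes torchvision is installed and that each dataset row stores
# the image under 'image' and the label under 'digit'.
def _transform_row(mnist_row):
    # Imported here so the rest of the excerpt does not require torchvision.
    from torchvision import transforms

    # Convert the stored image to a normalized float tensor; the 'idx' field
    # is dropped by removed_fields in the TransformSpec.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])
    return {'image': transform(mnist_row['image']),
            'digit': mnist_row['digit']}


if __name__ == '__main__':
    main()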