example/python-howto/data.py (23 lines of code) (raw):

# pylint: skip-file
"""Data iterators for MNIST."""
import sys
import os

# Make the shared test helpers (used below to fetch the dataset) importable
# from this example directory.
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
sys.path.append(os.path.join(curr_path, "../../tests/python/common"))
import get_data
import mxnet as mx


def mnist_iterator(batch_size, input_shape):
    """Return ``(train, val)`` data iterators for MNIST.

    Parameters
    ----------
    batch_size : int
        Number of samples per batch, forwarded to ``mx.io.MNISTIter``.
    input_shape : tuple
        Shape of a single input sample, forwarded to ``mx.io.MNISTIter``.
        A 3-element shape (presumably ``(channel, height, width)`` -- TODO
        confirm against callers) requests non-flattened images; any other
        length requests flattened input.

    Returns
    -------
    tuple
        ``(train_dataiter, val_dataiter)``. The training iterator is built
        with ``shuffle=True``; the validation iterator leaves ``shuffle`` at
        the MNISTIter default.
    """
    # Make sure the ubyte files are available. The paths below are relative
    # to the current working directory, so this presumably places them under
    # ``data/`` there -- NOTE(review): confirm in get_data.
    get_data.GetMNIST_ubyte()

    # Only a 3-D per-sample shape keeps the image layout; otherwise flatten.
    flat = len(input_shape) != 3

    train_dataiter = mx.io.MNISTIter(
        image="data/train-images-idx3-ubyte",
        label="data/train-labels-idx1-ubyte",
        input_shape=input_shape,
        batch_size=batch_size,
        shuffle=True,
        flat=flat)
    val_dataiter = mx.io.MNISTIter(
        image="data/t10k-images-idx3-ubyte",
        label="data/t10k-labels-idx1-ubyte",
        input_shape=input_shape,
        batch_size=batch_size,
        flat=flat)
    return (train_dataiter, val_dataiter)