def mnist_data_to_petastorm_dataset()

in examples/mnist/generate_petastorm_mnist.py


# Imports this snippet relies on (paths as in the petastorm repository; the
# helper download_mnist_data is defined elsewhere in the same example file):
import numpy as np
from pyspark.sql import SparkSession

from examples.mnist.schema import MnistSchema
from petastorm.etl.dataset_metadata import materialize_dataset
from petastorm.unischema import dict_to_spark_row


def mnist_data_to_petastorm_dataset(download_dir, output_url, spark_master=None, parquet_files_count=1,
                                    mnist_data=None):
    """Converts a directory with MNIST data into a petastorm dataset.

    Data files are as specified in http://yann.lecun.com/exdb/mnist/:
        * train-images-idx3-ubyte.gz:  training set images (9912422 bytes)
        * train-labels-idx1-ubyte.gz:  training set labels (28881 bytes)
        * t10k-images-idx3-ubyte.gz:   test set images (1648877 bytes)
        * t10k-labels-idx1-ubyte.gz:   test set labels (4542 bytes)

    The images and labels are stored in the IDX file format for vectors and multidimensional matrices of
    various numerical types, as defined in the same URL.

    :param download_dir: the path to which the MNIST data will be downloaded.
    :param output_url: the location the dataset will be written to. Must be a URL: either
      file://... or hdfs://...
    :param spark_master: a master parameter used by the Spark session builder. Leave as None to use the
      Spark cluster configured in the system environment; use 'local[*]' to run on the local machine.
    :param parquet_files_count: the number of Parquet files each dataset's output is coalesced into.
    :param mnist_data: a dictionary of MNIST data, keyed by dataset name ('train'/'test'), with the
      dataset object as the value; if None is supplied, the data is downloaded.
    :return: None
    """
    session_builder = SparkSession \
        .builder \
        .appName('MNIST Dataset Creation')
    if spark_master:
        session_builder.master(spark_master)

    spark = session_builder.getOrCreate()

    # Get training and test data; each dataset yields (PIL image, digit label) pairs
    if mnist_data is None:
        mnist_data = {
            'train': download_mnist_data(download_dir, train=True),
            'test': download_mnist_data(download_dir, train=False)
        }

    # The MNIST data is small enough to do everything here in Python
    for dset, data in mnist_data.items():
        dset_output_url = '{}/{}'.format(output_url, dset)
        # Using row_group_size_mb=1 to avoid having just a single rowgroup in this example. In a real store, the value
        # should be similar to an HDFS block size.
        with materialize_dataset(spark, dset_output_url, MnistSchema, row_group_size_mb=1):
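            # On exiting this context manager, materialize_dataset writes petastorm
            # metadata (including the serialized unischema) alongside the Parquet
            # files, which is what makes the output readable by petastorm readers.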
            # Lazily build one dict per sample, keyed by the schema field names:
            # idx (sample index), digit (the label) and image (a 28x28 uint8 numpy array)
            idx_image_digit_list = map(lambda idx_image_digit: {
                MnistSchema.idx.name: idx_image_digit[0],
                MnistSchema.digit.name: idx_image_digit[1][1],
                MnistSchema.image.name: np.array(list(idx_image_digit[1][0].getdata()), dtype=np.uint8).reshape(28, 28)
            }, enumerate(data))

            # Convert each dict to a pyspark.sql.Row, encoding every field with its unischema codec
            sql_rows = map(lambda r: dict_to_spark_row(MnistSchema, r), idx_image_digit_list)

            # Write out the result
            spark.createDataFrame(sql_rows, MnistSchema.as_spark_schema()) \
                .coalesce(parquet_files_count) \
                .write \
                .option('compression', 'none') \
                .parquet(dset_output_url)
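
A minimal usage sketch follows. The paths and the 'local[*]' master are illustrative,
and the read-back check uses petastorm's public make_reader API; the full example file
drives this function from a small command-line wrapper instead:

if __name__ == '__main__':
    # Hypothetical paths; adjust to your environment.
    mnist_data_to_petastorm_dataset(
        download_dir='/tmp/mnist_download',
        output_url='file:///tmp/mnist_petastorm',
        spark_master='local[*]',
        parquet_files_count=1)

    # Sanity check: read one sample back with petastorm.
    from petastorm import make_reader
    with make_reader('file:///tmp/mnist_petastorm/train') as reader:
        for sample in reader:
            print(sample.idx, sample.digit, sample.image.shape)
            break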