examples/keras_spark_mnist.py

import argparse
import os
import subprocess
from distutils.version import LooseVersion

import numpy as np

import pyspark
import pyspark.sql.types as T
from pyspark import SparkConf
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
if LooseVersion(pyspark.__version__) < LooseVersion('3.0.0'):
    from pyspark.ml.feature import OneHotEncoderEstimator as OneHotEncoder
else:
    from pyspark.ml.feature import OneHotEncoder
from pyspark.sql import SparkSession
from pyspark.sql.functions import udf

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, Flatten
from tensorflow.keras.layers import Conv2D, MaxPooling2D

import horovod.spark.keras as hvd
from horovod.spark.common.store import Store

parser = argparse.ArgumentParser(description='Keras Spark MNIST Example',
                                 formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--master',
                    help='spark master to connect to')
parser.add_argument('--num-proc', type=int,
                    help='number of worker processes for training, default: `spark.default.parallelism`')
parser.add_argument('--batch-size', type=int, default=128,
                    help='input batch size for training')
parser.add_argument('--epochs', type=int, default=12,
                    help='number of epochs to train')
parser.add_argument('--work-dir', default='/tmp',
                    help='temporary working directory to write intermediate files (prefix with hdfs:// to use HDFS)')
parser.add_argument('--data-dir', default='/tmp',
                    help='location of the training dataset in the local filesystem (will be downloaded if needed)')

if __name__ == '__main__':
    args = parser.parse_args()

    # Initialize SparkSession
    conf = SparkConf().setAppName('keras_spark_mnist').set('spark.sql.shuffle.partitions', '16')
    if args.master:
        conf.setMaster(args.master)
    elif args.num_proc:
        conf.setMaster('local[{}]'.format(args.num_proc))
    spark = SparkSession.builder.config(conf=conf).getOrCreate()

    # Setup our store for intermediate data
    store = Store.create(args.work_dir)

    # Download MNIST dataset
    data_url = 'https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/mnist.bz2'
    libsvm_path = os.path.join(args.data_dir, 'mnist.bz2')
    if not os.path.exists(libsvm_path):
        subprocess.check_output(['wget', data_url, '-O', libsvm_path])

    # Load dataset into a Spark DataFrame
    df = spark.read.format('libsvm') \
        .option('numFeatures', '784') \
        .load(libsvm_path)

    # One-hot encode labels into SparseVectors
    encoder = OneHotEncoder(inputCols=['label'],
                            outputCols=['label_vec'],
                            dropLast=False)
    model = encoder.fit(df)
    train_df = model.transform(df)

    # Train/test split
    train_df, test_df = train_df.randomSplit([0.9, 0.1])

    # Disable GPUs when building the model to prevent memory leaks
    if LooseVersion(tf.__version__) >= LooseVersion('2.0.0'):
        # See https://github.com/tensorflow/tensorflow/issues/33168
        os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
    else:
        keras.backend.set_session(tf.Session(config=tf.ConfigProto(device_count={'GPU': 0})))

    # Define the Keras model without any Horovod-specific parameters
    model = Sequential()
    model.add(Conv2D(32, kernel_size=(3, 3),
                     activation='relu',
                     input_shape=(28, 28, 1)))
    model.add(Conv2D(64, (3, 3), activation='relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Dropout(0.25))
    model.add(Flatten())
    model.add(Dense(128, activation='relu'))
    model.add(Dropout(0.5))
    model.add(Dense(10, activation='softmax'))

    optimizer = keras.optimizers.Adadelta(1.0)
    loss = keras.losses.categorical_crossentropy

    # Train a Horovod Spark Estimator on the DataFrame
    keras_estimator = hvd.KerasEstimator(num_proc=args.num_proc,
                                         store=store,
                                         model=model,
                                         optimizer=optimizer,
                                         loss=loss,
                                         metrics=['accuracy'],
                                         feature_cols=['features'],
                                         label_cols=['label_vec'],
                                         batch_size=args.batch_size,
                                         epochs=args.epochs,
                                         verbose=1)

    keras_model = keras_estimator.fit(train_df).setOutputCols(['label_prob'])

    # Evaluate the model on the held-out test DataFrame
    pred_df = keras_model.transform(test_df)
    argmax = udf(lambda v: float(np.argmax(v)), returnType=T.DoubleType())
    pred_df = pred_df.withColumn('label_pred', argmax(pred_df.label_prob))
    evaluator = MulticlassClassificationEvaluator(predictionCol='label_pred',
                                                  labelCol='label',
                                                  metricName='accuracy')
    print('Test accuracy:', evaluator.evaluate(pred_df))

    spark.stop()
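# Example invocation (a sketch only; the process count and the hdfs:// work dir
# below are illustrative assumptions, not part of the original script). The flags
# map directly onto the argparse options defined at the top of the file:
#
#   python examples/keras_spark_mnist.py --num-proc 2 \
#       --work-dir hdfs:///tmp/horovod-work \
#       --data-dir /tmp
#
# Omitting --master and --num-proc lets the script fall back to the current
# Spark configuration and `spark.default.parallelism`.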