def run()

in examples/spark_dataset_converter/pytorch_converter_example.py


import logging

from pyspark.sql import SparkSession
from pyspark.sql.functions import col

from petastorm.spark import SparkDatasetConverter, make_spark_converter


def run(data_dir):
    # Get or create a local SparkSession with two worker threads
    spark = SparkSession.builder \
        .master("local[2]") \
        .appName("petastorm.spark pytorch_example") \
        .getOrCreate()

    # Load the libsvm-format data with Spark (784 features per row)
    # and cast the labels to long
    df = spark.read.format("libsvm") \
        .option("numFeatures", "784") \
        .load(data_dir) \
        .select(col("features"), col("label").cast("long").alias("label"))

    # Randomly split the data into train (90%) and test (10%) datasets
    df_train, df_test = df.randomSplit([0.9, 0.1], seed=12345)

    # Set a cache directory for intermediate data.
    # The path should be accessible by both Spark workers and driver.
    spark.conf.set(SparkDatasetConverter.PARENT_CACHE_DIR_URL_CONF,
                   "file:///tmp/petastorm/cache/torch-example")

    # Materialize each DataFrame as Parquet under the cache dir and wrap it in
    # a converter that can produce framework-specific data loaders
    converter_train = make_spark_converter(df_train)
    converter_test = make_spark_converter(df_test)

    def train_and_evaluate(_=None):
        # The unused argument lets this function be passed to RDD.map() below.
        with converter_train.make_torch_dataloader() as loader:
            model = train(loader)

        # num_epochs=1 makes the loader stop after one pass over the test
        # data; by default it cycles through the data indefinitely.
        with converter_test.make_torch_dataloader(num_epochs=1) as loader:
            accuracy = test(model, loader)
        return accuracy

    # Train and evaluate the model on the local machine
    accuracy = train_and_evaluate()
    logging.info("Train and evaluate the model on the local machine.")
    logging.info("Accuracy: %.6f", accuracy)

    # Train and evaluate the model on a Spark worker. parallelize(range(1))
    # creates a one-element RDD, so train_and_evaluate runs exactly once
    # remotely; collect()[0] retrieves its accuracy on the driver.
    accuracy = spark.sparkContext.parallelize(range(1)) \
        .map(train_and_evaluate) \
        .collect()[0]
    logging.info("Train and evaluate the model remotely on a spark worker, "
                 "which can be used for distributed hyperparameter tuning.")
    logging.info("Accuracy: %.6f", accuracy)

    # Cleanup
    converter_train.delete()
    converter_test.delete()
    spark.stop()
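
The train() and test() helpers called above are defined elsewhere in the
example file and are not part of this excerpt. A minimal sketch of what they
might look like, assuming a small fully connected model over the 784 input
features (the Net class, optimizer settings, and step cap below are
illustrative assumptions, not the example's actual code):

import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    # Hypothetical model: one hidden layer over the 784 input features.
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        return self.fc2(F.relu(self.fc1(x)))


def train(loader, steps=100):
    # Petastorm's torch dataloader yields dict batches keyed by column name.
    model = Net()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    model.train()
    for step, batch in enumerate(loader):
        if step >= steps:  # the train loader cycles indefinitely, so cap steps
            break
        features = batch["features"].float()
        labels = batch["label"]
        optimizer.zero_grad()
        loss = F.cross_entropy(model(features), labels)
        loss.backward()
        optimizer.step()
    return model


def test(model, loader):
    # Accuracy over a single pass of the test loader (num_epochs=1 above).
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for batch in loader:
            features = batch["features"].float()
            labels = batch["label"]
            predictions = model(features).argmax(dim=1)
            correct += (predictions == labels).sum().item()
            total += labels.size(0)
    return correct / total

A hedged sketch of how run() might be invoked from the command line; the
--data_dir flag name and logging setup are assumptions for illustration:

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Petastorm spark dataset converter PyTorch example")
    parser.add_argument("--data_dir", type=str, required=True,
                        help="Path to the libsvm-format input data")
    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO)
    run(args.data_dir)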