# examples/spark_dataset_converter/pytorch_converter_example.py
import logging

from pyspark.sql import SparkSession
from pyspark.sql.functions import col

from petastorm.spark import SparkDatasetConverter, make_spark_converter

# NOTE: `train(dataloader)` and `test(model, dataloader)` are the PyTorch training
# and evaluation helpers, assumed to be defined elsewhere in this example file.


def run(data_dir):
    # Get SparkSession
    spark = SparkSession.builder \
        .master("local[2]") \
        .appName("petastorm.spark pytorch_example") \
        .getOrCreate()

    # Load and preprocess data using Spark
    df = spark.read.format("libsvm") \
        .option("numFeatures", "784") \
        .load(data_dir) \
        .select(col("features"), col("label").cast("long").alias("label"))
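    # `data_dir` is expected to hold data in libsvm format with 784 features per
    # row (28x28, i.e. flattened MNIST-style images), matching `numFeatures` above.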

    # Randomly split the data into train and test datasets
    df_train, df_test = df.randomSplit([0.9, 0.1], seed=12345)

    # Set a cache directory for intermediate data.
    # The path should be accessible by both Spark workers and driver.
    spark.conf.set(SparkDatasetConverter.PARENT_CACHE_DIR_URL_CONF,
                   "file:///tmp/petastorm/cache/torch-example")
    converter_train = make_spark_converter(df_train)
    converter_test = make_spark_converter(df_test)
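    # `make_spark_converter` materializes each DataFrame as Parquet under the
    # cache URL above; a `file:///tmp` path only works for this single-machine
    # `local[2]` setup. On a real cluster, use shared storage instead (e.g. an
    # HDFS or object-store URL).
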
    def train_and_evaluate(_=None):
        with converter_train.make_torch_dataloader() as loader:
            model = train(loader)
        with converter_test.make_torch_dataloader(num_epochs=1) as loader:
            accuracy = test(model, loader)
        return accuracy

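    # With the default `num_epochs=None`, the train loader yields batches
    # indefinitely, so `train` is assumed to stop after a fixed number of steps;
    # `num_epochs=1` makes the test loader do exactly one pass over the data.
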
    # Train and evaluate the model on the local machine
    accuracy = train_and_evaluate()
    logging.info("Train and evaluate the model on the local machine.")
    logging.info("Accuracy: %.6f", accuracy)

    # Train and evaluate the model on a Spark worker
    accuracy = spark.sparkContext.parallelize(range(1)) \
        .map(train_and_evaluate).collect()[0]
    logging.info("Train and evaluate the model remotely on a Spark worker, "
                 "which can be used for distributed hyperparameter tuning.")
    logging.info("Accuracy: %.6f", accuracy)
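    # The converters are serialized into the task closure, so the worker task
    # reads the cached Parquet data directly from the shared cache directory.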

    # Cleanup: delete the cached intermediate files and stop the Spark session
    converter_train.delete()
    converter_test.delete()
    spark.stop()
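

# A minimal entry point sketch: the data path below is hypothetical; the full
# example would typically parse it from the command line instead.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    run("/tmp/petastorm/mnist")  # hypothetical path to libsvm-format data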