In [1]:
# Enable the autoreload extension so that edits to imported modules are
# picked up automatically without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import warnings
warnings.filterwarnings('ignore')

# Spark HuggingFace Connector Demo

## Create a Spark Session

In [None]:
from pyspark.sql import SparkSession

# Reuse the active Spark session if there is one; otherwise start a new one
# with the executor memory setting below.
spark = SparkSession.builder.config("spark.executor.memory", "20G").getOrCreate()

## Load a dataset as a Spark DataFrame

By default, the connector loads the data as a streaming dataset, which is equivalent to `load_dataset(..., streaming=True)`.

In [2]:
# Read the rotten_tomatoes dataset through the "huggingface" data source.
df = (
    spark.read
    .format("huggingface")
    .load("cornell-movie-review-data/rotten_tomatoes")
)

In [3]:
# Display the column names and types of the loaded DataFrame.
df.printSchema()

root
 |-- text: string (nullable = true)
 |-- label: long (nullable = true)



In [4]:
# Cache the DataFrame so that later actions reuse the data instead of
# downloading it again. Only do this for small datasets.
# Nothing is computed yet: the cache is filled by the next action (e.g. count()).
df.cache()

DataFrame[text: string, label: bigint]

In [5]:
# Run an action to trigger the cache computation; count() also returns the
# number of rows in the DataFrame.
df.count()

24/11/27 15:11:14 WARN CheckAllocator: More than one DefaultAllocationManager on classpath. Choosing first found
                                                                                

8530

In [6]:
# Peek at the first row of the DataFrame.
df.head()

Row(text='the rock is destined to be the 21st century\'s new " conan " and that he\'s going to make a splash even greater than arnold schwarzenegger , jean-claud van damme or steven segal .', label=1)

In [7]:
# From here on this is a regular Spark DataFrame, so the usual operations
# apply -- for example, counting the rows whose label is 0.
df.where(df["label"] == 0).count()

4265

## Load a Dataset with a configuration/subset
Some datasets require the config (subset) name to be specified explicitly. Pass it as a data source option, e.g. `.option("config", "<config_name>")`.

## Load a different split
Use the `split` data source option to load a split other than the default one (here, the `test` split).

In [8]:
# Select the "test" split of the same dataset via the "split" option.
test_df = spark.read.format("huggingface").option("split", "test").load(
    "cornell-movie-review-data/rotten_tomatoes"
)

In [9]:
# Mark the test split for caching so that later actions can reuse the data.
test_df.cache()

DataFrame[text: string, label: bigint]

In [10]:
# Count the rows of the test split; as an action, this also fills the cache.
test_df.count()

                                                                                

1066

In [11]:
# Preview the first 5 rows of the test split.
test_df.show(5)

+--------------------+-----+
|                text|label|
+--------------------+-----+
|lovingly photogra...|    1|
|consistently clev...|    1|
|it's like a " big...|    1|
|the story gives a...|    1|
|red dragon " neve...|    1|
+--------------------+-----+
only showing top 5 rows


## Load a dataset with multiple shards

This example uses the [amazon_polarity dataset](https://huggingface.co/datasets/fancyzhx/amazon_polarity), which has 4 shards for the train split.

In [12]:
# Read the multi-shard amazon_polarity dataset through the "huggingface" data source.
df = (
    spark.read
    .format("huggingface")
    .load("fancyzhx/amazon_polarity")
)

In [13]:
# The DataFrame has 4 partitions, each corresponding to one shard of the dataset.
df.rdd.getNumPartitions()

4

## Load a dataset without streaming

This is equivalent to `load_dataset(..., streaming=False)`

In [None]:
# Setting the "streaming" option to "false" turns off the default streaming mode.
df = (
    spark.read
    .format("huggingface")
    .option("streaming", "false")
    .load("stanfordnlp/imdb")
)

In [15]:
# Preview the first 5 rows of the dataset.
df.show(5)

[Stage 13:>                                                         (0 + 1) / 1]

+--------------------+-----+
|                text|label|
+--------------------+-----+
|I rented I AM CUR...|    0|
|"I Am Curious: Ye...|    0|
|If only to avoid ...|    0|
|This film was pro...|    0|
|Oh, brother...aft...|    0|
+--------------------+-----+
only showing top 5 rows


                                                                                

In [16]:
# Keep only the rows whose label is 1 and preview the first 5 of them.
df.where(df["label"] == 1).show(5)

[Stage 14:>                                                         (0 + 1) / 1]

+--------------------+-----+
|                text|label|
+--------------------+-----+
|Zentropa has much...|    1|
|Zentropa is the m...|    1|
|Lars Von Trier is...|    1|
|*Contains spoiler...|    1|
|That was the firs...|    1|
+--------------------+-----+
only showing top 5 rows


                                                                                