in pyspark_huggingface/huggingface_source.py
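The partition-level read method of the reader: given one Shard, it yields that shard's rows to Spark, preferring Arrow record batches and falling back to a full download of the split when no streaming dataset is available.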
def read(self, partition: Shard):
    # Project to the columns Spark requested, in schema order.
    columns = [field.name for field in self.schema.fields]
    if self.streaming_dataset:
        # Streaming path: pick the single shard assigned to this Spark partition.
        shard = self.streaming_dataset.shard(num_shards=self.streaming_dataset.num_shards, index=partition.index)
        if shard._ex_iterable.iter_arrow:
            # Fast path: the shard exposes an Arrow-level iterator, so yield
            # record batches directly without materializing Python dicts.
            for _, pa_table in shard._ex_iterable.iter_arrow():
                yield from pa_table.select(columns).to_batches()
        else:
            # Fallback: the underlying examples iterable yields (key, example)
            # pairs; project each example to the requested columns and yield a
            # plain tuple, which Spark also accepts from read().
            for _, example in shard._ex_iterable:
                yield tuple(example[column] for column in columns)
    else:
        # Non-streaming path: download and materialize the whole split locally.
        self.builder.download_and_prepare()
        dataset = self.builder.as_dataset(self.split)
        # Get the underlying arrow table of the dataset
        table = dataset._data
        yield from table.select(columns).to_batches()
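
For context, a minimal usage sketch from the Spark side (assuming the package's documented "huggingface" format name and its registration-on-import behavior; the dataset id is illustrative):

    from pyspark.sql import SparkSession

    import pyspark_huggingface  # registers the "huggingface" data source

    spark = SparkSession.builder.getOrCreate()

    # Spark schedules one task per Shard partition; each task calls the
    # read(partition) method above and consumes the batches/tuples it yields.
    df = spark.read.format("huggingface").load("stanfordnlp/imdb")
    df.show(5)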