def read()

in pyspark_huggingface/huggingface_source.py

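Yields the rows for one Spark partition. In streaming mode, the partition is mapped onto a shard of the streaming dataset, preferring Arrow record batches over per-example iteration; otherwise the dataset is downloaded and batches are read from its underlying Arrow table.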

    def read(self, partition: Shard):
        # Only emit the columns that the Spark schema requests
        columns = [field.name for field in self.schema.fields]
        if self.streaming_dataset:
            # Map this Spark partition onto the matching shard of the streaming dataset
            shard = self.streaming_dataset.shard(num_shards=self.streaming_dataset.num_shards, index=partition.index)
            if shard._ex_iterable.iter_arrow:
                # Fast path: the shard exposes its data directly as Arrow tables
                for _, pa_table in shard._ex_iterable.iter_arrow():
                    yield from pa_table.select(columns).to_batches()
            else:
                # Slow path: iterating the shard yields example dicts, so emit
                # each row as a tuple in schema column order
                for example in shard:
                    yield tuple(example[column] for column in columns)
        else:
            # Non-streaming mode: download and prepare the dataset locally first
            self.builder.download_and_prepare()
            dataset = self.builder.as_dataset(self.split)
            # Get the underlying Arrow table of the dataset and yield its record batches
            table = dataset._data
            yield from table.select(columns).to_batches()
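
The partition-to-shard mapping above relies on IterableDataset.shard from the datasets library. A minimal standalone sketch of that mechanism, assuming a recent datasets release where num_shards and shard() are available, and using stanfordnlp/imdb purely as an illustrative dataset:

    from datasets import load_dataset

    # Stream the dataset instead of downloading it in full
    ds = load_dataset("stanfordnlp/imdb", split="train", streaming=True)

    # One Spark partition corresponds to one shard, selected by index
    shard = ds.shard(num_shards=ds.num_shards, index=0)
    for example in shard.take(2):
        print(example)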
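
For context, a hedged sketch of how this reader gets invoked from Spark, assuming Spark 4.0's Python Data Source API and that importing pyspark_huggingface registers the source under the short name "huggingface" (as the project README suggests); the dataset name is again illustrative:

    from pyspark.sql import SparkSession
    import pyspark_huggingface  # assumed to register the "huggingface" format on import

    spark = SparkSession.builder.getOrCreate()

    # Spark calls read() once per input partition, i.e. once per dataset shard
    df = spark.read.format("huggingface").load("stanfordnlp/imdb")
    df.show(5)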