in pyspark_huggingface/huggingface_sink.py
def write(self, iterator: Iterator["RecordBatch"]) -> HuggingFaceCommitMessage:
    import io

    from huggingface_hub import CommitOperationAdd
    from pyarrow import parquet as pq
    from pyspark import TaskContext
    from pyspark.sql.pandas.types import to_arrow_schema

    # Get the current partition ID. Use this to generate unique filenames for each partition.
    context = TaskContext.get()
    partition_id = context.partitionId() if context else 0

    api = self._get_api()

    schema = to_arrow_schema(self.schema)
    num_files = 0
    additions = []

    # TODO: Evaluate the performance of using a temp file instead of an in-memory buffer.
    with io.BytesIO() as parquet:
        def flush(writer: pq.ParquetWriter):
            """
            Upload the current Parquet file and reset the buffer.
            """
            writer.close()  # Close the writer to flush the buffer
            nonlocal num_files
            name = (
                f"{self.prefix}-{self.uuid}-part-{partition_id}-{num_files}.parquet"
            )
            num_files += 1
            parquet.seek(0)
            addition = CommitOperationAdd(
                path_in_repo=name, path_or_fileobj=parquet
            )
            # Upload the shard to the Hub now; the commit that references it is created
            # later from the collected additions.
            api.preupload_lfs_files(
                repo_id=self.repo_id,
                additions=[addition],
                repo_type=self.repo_type,
                revision=self.revision,
            )
            additions.append(addition)

            # Reuse the buffer for the next file
            parquet.seek(0)
            parquet.truncate()
"""
Write the Parquet files, flushing the buffer when the file size exceeds the limit.
Limiting the size is necessary because we are writing them in memory.
"""
        while True:
            with pq.ParquetWriter(parquet, schema, **self.kwargs) as writer:
                num_batches = 0
                for batch in iterator:  # Start iterating from where we left off
                    writer.write_batch(batch, row_group_size=self.row_group_size)
                    num_batches += 1
                    if parquet.tell() > self.max_bytes_per_file:
                        flush(writer)
                        break  # Start a new file
                else:  # Finished writing all batches
                    if num_batches > 0:
                        flush(writer)
                    break  # Exit while loop

    return HuggingFaceCommitMessage(additions=additions)
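
For context, a minimal usage sketch of this sink from the DataFrame API. It assumes the data source is registered under the "huggingface" format name when pyspark_huggingface is imported, and that the target repository ID is passed to save(); the mode and option names shown are illustrative rather than an authoritative list.

import pyspark_huggingface  # assumed: importing registers the "huggingface" data source
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(1, "hello"), (2, "world")], ["id", "text"])

# Each executor task runs write() above, uploading one or more Parquet shards for its
# partition; the returned HuggingFaceCommitMessage objects are collected on the driver,
# which creates the final commit.
(
    df.write.format("huggingface")
    .mode("overwrite")
    .option("token", "hf_xxx")  # hypothetical option name for the Hub token
    .save("username/my-dataset")
)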