def write()

in pyspark_huggingface/huggingface_sink.py


    def write(self, iterator: Iterator["RecordBatch"]) -> HuggingFaceCommitMessage:
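        """
        Write an iterator of Arrow record batches to one or more Parquet files,
        preupload them to the Hugging Face Hub, and return the pending commit
        operations for the driver to commit.
        """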
        import io

        from huggingface_hub import CommitOperationAdd
        from pyarrow import parquet as pq
        from pyspark import TaskContext
        from pyspark.sql.pandas.types import to_arrow_schema

        # Get the current partition ID. Use this to generate unique filenames for each partition.
        context = TaskContext.get()
        partition_id = context.partitionId() if context else 0

        api = self._get_api()

        schema = to_arrow_schema(self.schema)
        num_files = 0
        additions = []

        # TODO: Evaluate the performance of using a temp file instead of an in-memory buffer.
        with io.BytesIO() as parquet:

            def flush(writer: pq.ParquetWriter):
                """
                Upload the current Parquet file and reset the buffer.
                """
                # Close the writer so the footer is written and the buffer
                # holds a complete Parquet file.
                writer.close()
                nonlocal num_files
                name = (
                    f"{self.prefix}-{self.uuid}-part-{partition_id}-{num_files}.parquet"
                )
                num_files += 1
                parquet.seek(0)

                addition = CommitOperationAdd(
                    path_in_repo=name, path_or_fileobj=parquet
                )
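                # Upload the file contents to the Hub's LFS storage now; the
                # commit itself is created later on the driver from the collected
                # CommitOperationAdd operations, so the buffer can be reused once
                # the upload completes.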
                api.preupload_lfs_files(
                    repo_id=self.repo_id,
                    additions=[addition],
                    repo_type=self.repo_type,
                    revision=self.revision,
                )
                additions.append(addition)

                # Reuse the buffer for the next file
                parquet.seek(0)
                parquet.truncate()

            """
            Write the Parquet files, flushing the buffer when the file size exceeds the limit.
            Limiting the size is necessary because we are writing them in memory.
            """
            while True:
                with pq.ParquetWriter(parquet, schema, **self.kwargs) as writer:
                    num_batches = 0
                    for batch in iterator:  # Start iterating from where we left off
                        writer.write_batch(batch, row_group_size=self.row_group_size)
                        num_batches += 1
                        if parquet.tell() > self.max_bytes_per_file:
                            flush(writer)
                            break  # Start a new file
                    else:  # Iterator exhausted without a break: all batches written
                        if num_batches > 0:
                            flush(writer)
                        break  # Exit while loop

        return HuggingFaceCommitMessage(additions=additions)
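
The additions returned here are only preuploaded LFS files; nothing becomes visible in the dataset repo until a single commit is created on the driver after every partition has finished. Below is a minimal sketch of that driver-side step, assuming a commit method on the same writer class that receives the per-partition messages (the method name, the commit message text, and the omitted imports and enclosing class are assumptions; preupload_lfs_files followed by create_commit is the standard huggingface_hub pairing):

    def commit(self, messages: List[HuggingFaceCommitMessage]) -> None:
        # Collect the preuploaded operations from every partition and register
        # them in a single commit, so all files land in one revision.
        operations = [
            addition for message in messages for addition in message.additions
        ]
        api = self._get_api()
        api.create_commit(
            repo_id=self.repo_id,
            repo_type=self.repo_type,
            revision=self.revision,
            operations=operations,
            commit_message="Upload Parquet files with pyspark_huggingface",
        )

Because preupload_lfs_files has already pushed the file bytes from the executors, create_commit only has to register the collected operations, which keeps the final commit lightweight and applies all files in a single revision.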