in pyspark_huggingface/huggingface_sink.py
def commit(self, messages: List[HuggingFaceCommitMessage]) -> None:  # type: ignore[override]
    """
    Commit the pre-uploaded Parquet files to the Hugging Face Hub, renaming them to match the
    expected format `{split}-{current:05d}-of-{total:05d}.parquet`, e.g. `train-00000-of-00003.parquet`.
    Existing files of the split are also deleted or renamed, depending on the mode.
    """
    from huggingface_hub import CommitOperationCopy, CommitOperationDelete
    from huggingface_hub.hf_api import RepoFile, RepoFolder

    api = self._get_api()
    additions = [addition for message in messages for addition in message.additions]
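    # Operations are keyed by path so an addition can supersede a pending deletion at the same path.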
    operations = {}
    count_new = len(additions)
    count_existing = 0
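
    # Builds the standard shard name, e.g. "train-00002-of-00010.parquet" for a 10-part split.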
    def format_path(i):
        return f"{self.prefix}-{i:05d}-of-{count_new + count_existing:05d}.parquet"
    def rename(old_path, new_path):
        if old_path != new_path:
            yield CommitOperationCopy(
                src_path_in_repo=old_path, path_in_repo=new_path
            )
            yield CommitOperationDelete(path_in_repo=old_path)

    # In overwrite mode, delete existing files
    if self.overwrite:
        for obj in self._list_split(api):
            # Delete the existing file or folder
            operations[obj.path] = CommitOperationDelete(
                path_in_repo=obj.path, is_folder=isinstance(obj, RepoFolder)
            )
    # In append mode, rename existing files to have the correct total number of parts
    else:
        rename_operations = []
        existing = [obj for obj in self._list_split(api) if isinstance(obj, RepoFile)]
        count_existing = len(existing)
        for i, obj in enumerate(existing):
            new_path = format_path(i)
            rename_operations.extend(rename(obj.path, new_path))
        # Rename files in a separate commit to prevent them from being overwritten by new files of the same name
        self._create_commits(
            api,
            operations=rename_operations,
            message="Rename existing files before uploading new files using PySpark",
        )

    # Rename additions, putting them after existing files if any
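    # e.g. appending 3 new files to a split with 2 existing parts yields parts 00002..00004 of 00005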
    for i, addition in enumerate(additions):
        addition.path_in_repo = format_path(i + count_existing)
        # Overwrite the deletion operation if the file already exists
        operations[addition.path_in_repo] = addition

    # Upload the new files
    self._create_commits(
        api,
        operations=list(operations.values()),
        message="Upload using PySpark",
    )
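
# Usage sketch (assumptions: the sink is registered under the "huggingface" format name and the
# save path is a `username/dataset_name` repo id; neither is confirmed by this method alone):
#
#   from pyspark.sql import SparkSession
#   spark = SparkSession.builder.getOrCreate()
#   df = spark.createDataFrame([("hello",), ("world",)], ["text"])
#   df.write.format("huggingface").mode("append").save("username/my_dataset")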