vision/m4/utils/datasets/create_webdataset_tar.py (316 lines of code) (raw):

import json import logging import math import os import shutil import subprocess import uuid from functools import partial # import PIL.Image from pathlib import Path from typing import List, Optional, Union from datasets import DatasetDict, concatenate_datasets from pathos.multiprocessing import ProcessingPool as Pool from m4.training.dataset import load_hf_dataset from m4.training.types import DatasetTypes from m4.training.utils import _convert_to_rgb logger = logging.getLogger(__name__) def check_img_exception(img): try: _ = img.convert("RGB") return False except Exception as e: logger.info(e) return True # Utils for web documents def save_web_document_example_in_files(example, idx, saving_dir): example_id = f"{idx}_{str(uuid.uuid4())}" saved = 0 num_img = 0 file_paths = [] for i, (text, image) in enumerate(zip(example["texts"], example["images"])): if text is not None and text != "": text_file_path = saving_dir / f"{example_id}.{i}.text.txt" with open(text_file_path, "w") as f: f.write(text) file_paths.append(text_file_path) elif image is not None and image != "": num_img += 1 image = _convert_to_rgb(image) if check_img_exception(image): logger.info(f"Example {idx} has image with exception") continue image_path = saving_dir / f"{example_id}.{i}.image.jpeg" image.save(image_path, "jpeg") saved += 1 file_paths.append(image_path) if saved == 0: for file_path in file_paths: os.remove(file_path) return {"saved": saved, "num_img": num_img} if len(file_paths) == 0: return {"saved": saved, "num_img": num_img} metadata_file_path = saving_dir / f"{example_id}.metadata.txt" with open(metadata_file_path, "w") as f: f.write("\n".join([path.name for path in file_paths])) return {"saved": saved, "num_img": num_img} def save_web_document_example_in_files_with_num_shards( example, idx, saving_dir, num_examples_per_shard, save_shard_prefix ): shard_idx = idx // num_examples_per_shard saving_dir_shard = saving_dir / f"shard_{save_shard_prefix}{shard_idx}" return save_web_document_example_in_files(example, idx, saving_dir_shard) # Utils for image caption pairs def save_image_caption_pair_example_in_files(example, idx, saving_dir): image = example["image"] if image is None: logger.info(f"Example {idx} has None as image") return {"saved": 0, "num_img": 1} image = _convert_to_rgb(image) if check_img_exception(image): logger.info(f"Example {idx} has image with exception") return {"saved": 0, "num_img": 1} example_id = f"{idx}_{str(uuid.uuid4())}" image_path = saving_dir / f"{example_id}.image.jpeg" image.save(image_path, "jpeg") text_file_path = saving_dir / f"{example_id}.text.txt" with open(text_file_path, "w") as f: f.write(example["text"]) return {"saved": 1, "num_img": 1} def save_image_caption_pair_example_in_files_with_num_shards( example, idx, saving_dir, num_examples_per_shard, save_shard_prefix ): shard_idx = idx // num_examples_per_shard saving_dir_shard = saving_dir / f"shard_{save_shard_prefix}{shard_idx}" return save_image_caption_pair_example_in_files(example, idx, saving_dir_shard) # Utils for image/question/answer triplets (in particular for specific fine-tuning) def save_image_question_answer_triplet_example_in_files(example, idx, saving_dir): image = example["image"] if image is None: logger.info(f"Example {idx} has None as image") return {"saved": 0, "num_img": 1} image = _convert_to_rgb(image) if check_img_exception(image): logger.info(f"Example {idx} has image with exception") return {"saved": 0, "num_img": 1} example_id = f"{idx}_{str(uuid.uuid4())}" image_path = saving_dir / f"{example_id}.image.jpeg" image.save(image_path, "jpeg") question_file_path = saving_dir / f"{example_id}.question.txt" with open(question_file_path, "w") as f: f.write(example["question"]) answer_file_path = saving_dir / f"{example_id}.answer.txt" with open(answer_file_path, "w") as f: f.write(example["answer"]) return {"saved": 1, "num_img": 1} def save_image_question_answer_triplet_example_in_files_with_num_shards( example, idx, saving_dir, num_examples_per_shard, save_shard_prefix ): shard_idx = idx // num_examples_per_shard saving_dir_shard = saving_dir / f"shard_{save_shard_prefix}{shard_idx}" return save_image_question_answer_triplet_example_in_files(example, idx, saving_dir_shard) # Utils for sft datasets now that they are all under the same format def save_sft_example_in_files(example, idx, saving_dir): example_id = f"{idx}_{str(uuid.uuid4())}" saved = 0 num_img = 0 file_paths = [] for i, image in enumerate(example["images"]): num_img += 1 image = _convert_to_rgb(image) if check_img_exception(image): logger.info(f"Example {idx} has image with exception") continue if check_img_exception(image): logger.info(f"Example {idx} has image with exception") continue image_path = saving_dir / f"{example_id}.{i}.image.jpeg" image.save(image_path, "jpeg") saved += 1 file_paths.append(image_path) for i, text in enumerate(example["texts"]): text_file_path = saving_dir / f"{example_id}.{i + num_img}.text.txt" with open(text_file_path, "w") as f: json.dump(text, f) # Dump the user/assistant dict as a string into a text file file_paths.append(text_file_path) if len(file_paths) == 0: return {"saved": saved, "num_img": num_img} metadata_file_path = saving_dir / f"{example_id}.metadata.txt" with open(metadata_file_path, "w") as f: f.write("\n".join([path.name for path in file_paths])) return {"saved": saved, "num_img": num_img} def save_sft_example_in_files_with_num_shards(example, idx, saving_dir, num_examples_per_shard, save_shard_prefix): shard_idx = idx // num_examples_per_shard shard_idx_padded = str(shard_idx).zfill(7) saving_dir_shard = saving_dir / f"shard_{save_shard_prefix}{shard_idx_padded}" return save_sft_example_in_files(example, idx, saving_dir_shard) # General utils def save_example_in_files(example, idx, saving_dir, ds_type): if ds_type == DatasetTypes.WEB_DOCUMENTS: saved = save_web_document_example_in_files(example, idx, saving_dir) elif ds_type == DatasetTypes.IMAGE_CAPTION_PAIRS: saved = save_image_caption_pair_example_in_files(example, idx, saving_dir) elif (ds_type == DatasetTypes.DOCVQA) or (ds_type == DatasetTypes.VQAV2_TASK_FINETUNING): saved = save_image_question_answer_triplet_example_in_files(example, idx, saving_dir) elif ds_type == DatasetTypes.SFT: saved = save_sft_example_in_files(example, idx, saving_dir) else: raise ValueError(f"Unsupported dataset type {ds_type}") return saved def save_example_in_files_with_num_shards( example, idx, saving_dir, ds_type, num_examples_per_shard, save_shard_prefix ): if ds_type == DatasetTypes.WEB_DOCUMENTS: saved = save_web_document_example_in_files_with_num_shards( example, idx, saving_dir, num_examples_per_shard, save_shard_prefix ) elif ds_type == DatasetTypes.IMAGE_CAPTION_PAIRS: saved = save_image_caption_pair_example_in_files_with_num_shards( example, idx, saving_dir, num_examples_per_shard, save_shard_prefix ) elif (ds_type == DatasetTypes.DOCVQA) or (ds_type == DatasetTypes.VQAV2_TASK_FINETUNING): saved = save_image_question_answer_triplet_example_in_files_with_num_shards( example, idx, saving_dir, num_examples_per_shard, save_shard_prefix ) elif ds_type == DatasetTypes.SFT: saved = save_sft_example_in_files_with_num_shards( example, idx, saving_dir, num_examples_per_shard, save_shard_prefix ) else: raise ValueError(f"Unsupported dataset type {ds_type}") return saved def export_dataset_all_shard_idx_to_tar( hf_datasets_paths: List[Union[str, Path]], saving_dir: Union[str, Path], ds_type: DatasetTypes, num_examples_per_shard: int, s3_uri: Optional[str] = None, num_proc: Optional[int] = None, min_num_shards: Optional[int] = None, save_shard_prefix: str = "", shard_idx: Optional[int] = None, save_shard_idx: Optional[str] = None, ): if save_shard_idx is not None: raise NotImplementedError("Use of `save_shard_idx` has been deprecated.") if num_proc is None: # by default the value of num_proc will be the minimum between 6 and the number of cpus num_proc = min(6, os.cpu_count()) logger.info("Start loading the dataset") dataset_list = [load_hf_dataset(str(hf_dataset_path)) for hf_dataset_path in hf_datasets_paths] ds = concatenate_datasets(dataset_list) if isinstance(ds, DatasetDict): raise ValueError("DatasetDict not supported") if num_examples_per_shard is None: num_shards = 1 else: num_shards = math.ceil(len(ds) / num_examples_per_shard) if min_num_shards is not None and num_shards < min_num_shards: num_examples_per_shard = len(ds) // min_num_shards logger.info( f"Number of examples per shard is too low, setting it to {num_examples_per_shard}. Without this, the" f" number of shards would be {num_shards} which is lower than the minimum number of shards" f" {min_num_shards}" ) num_shards = min_num_shards num_examples_per_shard = len(ds) // num_shards num_shards = len(ds) // num_examples_per_shard logger.info(f"Number of shards: {num_shards} and number of examples per shard: {num_examples_per_shard}") if shard_idx is None: for idx in range(num_shards + 1): idx_leading_0s = str(idx).zfill(7) saving_dir_shard = saving_dir / f"shard_{save_shard_prefix}{idx_leading_0s}" saving_dir_shard.mkdir(parents=True, exist_ok=True) logger.info(f"The dataset has {len(ds)} examples and the columns are {ds.column_names}") ds_saved = ds.map( partial( save_example_in_files_with_num_shards, saving_dir=saving_dir, ds_type=ds_type, num_examples_per_shard=num_examples_per_shard, save_shard_prefix=save_shard_prefix, ), with_indices=True, num_proc=num_proc, load_from_cache_file=False, remove_columns=ds.column_names, ) else: ds = ds.shard(num_shards=num_shards, index=shard_idx) saving_dir_shard = saving_dir / f"shard_{save_shard_prefix}{shard_idx}" saving_dir_shard.mkdir(parents=True, exist_ok=True) logger.info(f"The dataset has {len(ds)} examples and the columns are {ds.column_names}") ds_saved = ds.map( partial(save_example_in_files, saving_dir=saving_dir_shard, ds_type=ds_type), with_indices=True, num_proc=num_proc, load_from_cache_file=False, ) finished_file_path = saving_dir / f"shard_{save_shard_prefix}_finished.txt" finished_file_path.touch() num_images = sum(ds_saved["num_img"]) num_saved = sum(ds_saved["saved"]) logger.info( f"Shard {save_shard_prefix} has {num_images} images and out of {num_saved} saved" f" ({num_saved / num_images * 100:.2f}" ) def tar_shard_and_send_to_s3(saving_dir_shard): # check if the shard exists and is not empty if not os.path.exists(saving_dir_shard) or not os.listdir(saving_dir_shard): return tar_file = saving_dir_shard.parent / f"{saving_dir_shard.name}.tar" # Create tar file tar_cmd = ["tar", "--sort=name", "-cf", str(tar_file), "-C", str(saving_dir_shard), "."] subprocess.run(tar_cmd, check=True) # Remove original directory shutil.rmtree(saving_dir_shard, ignore_errors=True) # Upload to S3 if necessary if s3_uri is not None: s3_uri_file = f"{s3_uri}/{tar_file.name}" sync_cmd = ["s5cmd", "cp", str(tar_file), s3_uri_file] subprocess.run(sync_cmd, check=True) return if shard_idx is None: args_pool = [saving_dir / f"shard_{save_shard_prefix}{str(idx).zfill(7)}" for idx in range(num_shards + 1)] else: args_pool = [saving_dir / f"shard_{save_shard_prefix}{shard_idx}"] pool = Pool(num_proc) results = pool.amap(tar_shard_and_send_to_s3, args_pool) results = results.get() return 0 def export_dataset_to_tar( hf_datasets_paths: List[Union[str, Path]], saving_dir: Union[str, Path], ds_type: DatasetTypes, num_examples_per_shard: int, num_proc: Optional[int] = None, ): return export_dataset_all_shard_idx_to_tar( hf_datasets_paths=hf_datasets_paths, saving_dir=saving_dir, ds_type=ds_type, num_examples_per_shard=num_examples_per_shard, num_proc=num_proc, ) def export_dataset_shard_idx_to_tar( hf_datasets_paths: List[Union[str, Path]], saving_dir: Union[str, Path], ds_type: DatasetTypes, num_examples_per_shard: int, s3_uri: Optional[str] = None, num_proc: Optional[int] = None, shard_idx: int = 0, min_num_shards: Optional[int] = None, save_shard_idx: Optional[str] = None, ): logger.warning( "`export_dataset_shard_idx_to_tar` is deprecated, please favor `export_dataset_all_shard_idx_to_tar`." ) return export_dataset_all_shard_idx_to_tar( hf_datasets_paths=hf_datasets_paths, saving_dir=saving_dir, ds_type=ds_type, num_examples_per_shard=num_examples_per_shard, s3_uri=s3_uri, num_proc=num_proc, min_num_shards=min_num_shards, save_shard_prefix="", shard_idx=shard_idx, save_shard_idx=save_shard_idx, ) if __name__ == "__main__": hf_datasets_paths = ["HuggingFaceM4/tmp-pmd-synthetic-testing:100.unique"] saving_dir = Path("/home/lucile/data/tmp-pmd-synthetic-testing-100-unique-tar") num_examples_per_shard = 20 num_proc = 32 ds_type = DatasetTypes.IMAGE_CAPTION_PAIRS export_dataset_to_tar(hf_datasets_paths, saving_dir, ds_type, num_examples_per_shard, num_proc)