# toolkits/multimodal_data_preprocessing/convert_custom_dataset_to_wds_chatml.py
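"""Convert a ShareGPT-style multimodal dataset into ChatML webdataset shards and
generate the megatron-energon `.nv-meta` configuration files.

A minimal sketch of one `dataset.json` entry, assuming the common ShareGPT layout
(the field names below are the ones `convert` actually reads; the values and the
exact conversation schema are illustrative, not prescribed by this script):

    {
        "id": "sample-0",
        "images": ["images/0.jpg"],
        "videos": ["videos/clip0.mp4"],
        "conversations": [
            {"from": "human", "value": "<image> Describe the image."},
            {"from": "gpt", "value": "A cat sitting on a sofa."}
        ]
    }

For each video, pre-extracted frames are read from a folder named after the video
without its extension (`videos/clip0/`), and an optional sidecar `videos/clip0.json`
containing `{"fps": ...}` overrides the default fps.
"""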

import json
import os
import pickle
from argparse import ArgumentParser

import cv2
import webdataset as wds
import yaml
from megatron.energon.epathlib import EPath
from megatron.energon.flavors import BaseWebdatasetFactory
from tqdm import tqdm
from webdataset.writer import add_handlers, default_handlers, imageencoder


def convert(dataset_dir, json_name, sort_function=sorted, max_count=10000):
    """
    An example of converting a LLaVA-Pretrain-style dataset into ChatMLSample webdataset shards.
    """
    # Paths to the dataset files
    json_file = os.path.join(dataset_dir, json_name)
    output = os.path.join(dataset_dir, 'wds')
    if not os.path.exists(output):
        os.mkdir(output)

    # Load data
    with open(json_file, 'r') as f:
        data = json.load(f)

    # Custom webdataset ShardWriter encoders: pickle a list of jpg-encoded images
    # per sample for "jpgs", and a list of frame lists per sample for "videos".
    add_handlers(default_handlers, "jpgs",
                 lambda data: pickle.dumps([imageencoder(d, "jpg") for d in data]))
    add_handlers(default_handlers, "videos",
                 lambda data: pickle.dumps([[imageencoder(d, "jpg") for d in video] for video in data]))

    has_idx = None
    with wds.ShardWriter(os.path.join(output, 'pretrain-%d.tar'), maxcount=max_count) as shard_writer:
        for idx, entry in enumerate(tqdm(data)):
            # NOTE: read a dataset in ShareGPT format
            image_datas = []
            for image in entry.pop('images', []):
                image_datas.append(cv2.imread(os.path.join(dataset_dir, image), cv2.IMREAD_UNCHANGED))

            video_datas = []
            second_per_grid_ts = []
            for video in entry.pop('videos', []):
                video_noext, _ = os.path.splitext(video)
                frame_folder = os.path.join(dataset_dir, video_noext)
                # NOTE: we implicitly require a `${frame_folder}.json` file containing the fps rate
                # of each video; otherwise fps falls back to 2.0.
                if os.path.exists(frame_folder + '.json'):
                    with open(frame_folder + '.json', 'r') as f:
                        fps = float(json.load(f)['fps'])
                else:
                    fps = 2.0
                frames = []
                for frame in sort_function(os.listdir(frame_folder)):
                    frames.append(cv2.imread(os.path.join(frame_folder, frame), cv2.IMREAD_UNCHANGED))
                # Drop the trailing frame so the frame count is even.
                if len(frames) % 2 == 1:
                    frames = frames[:-1]
                video_datas.append(frames)
                second_per_grid_ts.append(1 / fps)

            if has_idx is None:
                has_idx = 'id' in entry
            assert has_idx == ('id' in entry), "Either all entries contain an id or none do."
            sample = {
                "__key__": entry.pop('id', str(idx)),
                "jpgs": image_datas,
                "videos": video_datas,
                "json": json.dumps({
                    'conversations': entry['conversations'],
                    'second_per_grid_ts': second_per_grid_ts,
                }).encode("utf-8"),
            }
            shard_writer.write(sample)
    print("Dataset successfully converted to wds")
    return output


def generate_configs(path: EPath, split, shuffle_tars=True, num_workers=32):
    path = path.absolute()
    all_tars = list(path.glob("**/*.tar")) + list(path.glob("**/*.tgz"))
    all_tars = [str(p.relative_to(path)) for p in sorted(all_tars)]
    split_parts_ratio = [("train", split[0]), ("val", split[1]), ("test", split[2])]
    split_parts_patterns = None

    # NOTE: generate .info.yaml and split.yaml
    _ = BaseWebdatasetFactory.prepare_dataset(
        path,
        all_tars,
        split_parts_ratio=split_parts_ratio,
        split_parts_patterns=split_parts_patterns,
        tar_index_only=False,
        shuffle_seed=42 if shuffle_tars else None,
        workers=num_workers,
    )

    # NOTE: dump dataset.yaml mapping webdataset fields to ChatMLSample fields
    metadata = {
        '__class__': 'ChatMLWebdataset',
        '__module__': 'megatron_patch.data.energon.chatml',
        'field_map': {
            'imgs': 'jpgs',
            'videos': 'videos',
            'conversation': 'json',
        },
    }
    with open(os.path.join(path.url, '.nv-meta', 'dataset.yaml'), 'w') as f:
        yaml.safe_dump(metadata, f)


if __name__ == '__main__':
    argparser = ArgumentParser()
    argparser.add_argument('--dataset-root', required=True, type=str)
    argparser.add_argument('--json', default='dataset.json', type=str)
    argparser.add_argument('--max-samples-per-tar', default=10000, type=int)
    argparser.add_argument('--train-split', default=9, type=float)
    argparser.add_argument('--val-split', default=1, type=float)
    argparser.add_argument('--test-split', default=0, type=float)
    args = argparser.parse_args()

    output_dir = convert(args.dataset_root, args.json, max_count=args.max_samples_per_tar)

    print("Generating configurations")
    # NOTE: split ratio is train/val/test
    split = [args.train_split, args.val_split, args.test_split]
    generate_configs(EPath(output_dir), split)
    print("Configurations generated")
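# A minimal usage sketch (the dataset root below is a hypothetical path); it assumes
# `${dataset-root}/dataset.json` and the image files / frame folders it references exist:
#
#   python convert_custom_dataset_to_wds_chatml.py \
#       --dataset-root /data/llava_pretrain \
#       --json dataset.json \
#       --max-samples-per-tar 10000 \
#       --train-split 9 --val-split 1 --test-split 0
#
# This writes `${dataset-root}/wds/pretrain-%d.tar` shards, then a `.nv-meta/` folder
# next to them holding `dataset.yaml` (written above) plus the `.info.yaml` and
# `split.yaml` produced by `BaseWebdatasetFactory.prepare_dataset`.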