toolkits/multimodal_data_preprocessing/convert_llava_pretrain_to_wds.py (25 lines of code) (raw):

"""Convert a LLaVA pretrain annotation file and its images into WebDataset shards.

Usage:
    python convert_llava_pretrain_to_wds.py <llava_pretrain_dir>

The directory must contain the annotation JSON (``blip_laion_cc_sbu_558k.json``
by default) plus the image files referenced by each entry's ``image`` field.
Shards are written to ``<llava_pretrain_dir>/wds/`` using the name pattern
``pretrain-%d.tar``.
"""

import json
import os
import sys

import webdataset as wds
from tqdm import tqdm


def convert(llava_pretrain_dir, json_name='blip_laion_cc_sbu_558k.json', maxcount=10000):
    """Pack every annotated image of ``llava_pretrain_dir`` into WebDataset tar shards.

    Each sample holds:
      * ``__key__`` -- the entry's ``id`` field,
      * ``jpg``     -- the raw bytes of the file named by the entry's ``image`` field,
      * ``json``    -- the entry's ``conversations`` field, JSON-encoded as UTF-8 bytes.

    Args:
        llava_pretrain_dir: Dataset root. It holds the annotation JSON and the
            images, whose paths are resolved relative to this directory.
        json_name: Annotation file name inside ``llava_pretrain_dir``.
        maxcount: Maximum number of samples per shard.

    Raises:
        FileNotFoundError: If the annotation file or a referenced image is missing.
        KeyError: If an entry lacks ``id``, ``image`` or ``conversations``.
    """
    json_file = os.path.join(llava_pretrain_dir, json_name)
    output = os.path.join(llava_pretrain_dir, 'wds')
    # exist_ok replaces the racy "check if exists, then mkdir" sequence.
    os.makedirs(output, exist_ok=True)

    # Explicit encoding so parsing does not depend on the platform's locale.
    with open(json_file, 'r', encoding='utf-8') as f:
        data = json.load(f)

    with wds.ShardWriter(os.path.join(output, 'pretrain-%d.tar'), maxcount=maxcount) as shard_writer:
        for entry in tqdm(data):
            with open(os.path.join(llava_pretrain_dir, entry['image']), "rb") as img_file:
                image_data = img_file.read()
            sample = {
                "__key__": entry['id'],
                # Bytes are stored as-is under the "jpg" extension whatever the
                # source file's real format is. NOTE(review): downstream loaders
                # presumably rely on this exact key; confirm before changing it.
                "jpg": image_data,
                "json": json.dumps(entry['conversations']).encode("utf-8"),
            }
            shard_writer.write(sample)

    print("Dataset successfully converted to wds")


if __name__ == '__main__':
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} <llava_pretrain_dir>")
    convert(sys.argv[1])