scripts/pre_encode.py (350 lines of code) (raw):

# This script is used to pre encode coyo, laion 6a, and laion 5a. # # It can be run as both a standalone job or via slurm. When run via slurm, be # sure to pass `--slurm` so the script can split shards amongst workers based on # the env vars `$SLURM_NTASKS` and `$SLURM_PROCID`. It is intended that one copy # of the script is launched per gpu, and cpu access is controlled implicitly # through slurm setting `$CUDA_VISIBLE_DEVICES`. See # ../slurm_scrips/{pre_encoded_laion_6, pre_encode_laion_5, # pre_encode_coyo}.slurm for example sbatch scripts. # # Benchmarks: # COYO) 64.1 GPU * sec / shard # laion) 75 GPU * sec / shard # # To convert a time per shard into a time to convert the # whole dataset, use # X (GPU * sec / shard) * Y shards * 1/8 (nodes/GPU) * 1/Z nodes = seconds to encode Y shards # # Shard counts: # COYO) 74,752 shards (0-74,751) # laion 6a) 1,211 shards (0 - 1,210) # laion 5a) 60,581 shards (0 - 60,580) # # Encoding times using 8 nodes: # COYO) 20h48m # laion 6a) 23.4 minutes # laion 5a) 19h43m import argparse import concurrent.futures import logging import os import re from collections import OrderedDict from threading import Lock import numpy as np import torch import torchvision.transforms.functional as TF import webdataset as wds from torch.utils.data import DataLoader from torchvision.transforms import InterpolationMode from transformers import CLIPTextModel, CLIPTokenizerFast from muse import PaellaVQModel, VQGANModel torch.set_float32_matmul_precision("high") torch.set_grad_enabled(False) PAELLA_F8_VQVAE = "openMUSE/paellavq-f8-8192-laion" VQGAN_F16_VQVAE = "openMUSE/vqgan-f16-8192-laion" CLIP = "openMUSE/CLIP-ViT-L-14-DataComp.XL-s13B-b90K-penultimate" PAELLA_F8_VQVAE_EXT = f"{'.'.join(PAELLA_F8_VQVAE.split('/'))}.pth" VQGAN_F16_VQVAE_EXT = f"{'.'.join(VQGAN_F16_VQVAE.split('/'))}.pth" CLIP_EXT = f"{'.'.join(CLIP.split('/'))}.pth" LAION_AESTHETICS_V2_5_PLUS = "s3://hf-datasets-laion-5b-us-west-2/glacier/laion-data/laion-aesthetics-v2-5-plus-data" LAION_AESTHETICS_V2_6_PLUS = "s3://muse-datasets/laion-aesthetic6plus-data" COYO = "s3://hf-datasets-coyo-700m-us-west-2/data" LAION_AESTHETICS_V2_5_PLUS_PRE_ENCODED = "s3://muse-datasets/hf-datasets-laion-aesthetics-v2-5-plus-data-pre-encoded" LAION_AESTHETICS_V2_6_PLUS_PRE_ENCODED = "s3://muse-datasets/hf-datasets-laion-aesthetic6plus-data-pre-encoded" COYO_PRE_ENCODED = "s3://muse-datasets/hf-datasets-coyo-700m-pre-encoded" logger = logging.getLogger(__name__) tar_regex = r"\/([^\/]+\.tar)" def get_tar_file_name(url): match = re.search(tar_regex, url) assert match is not None, url tar_file_name = match.group(1) return tar_file_name def format_shard_number(shard_n: int): return "{:0>{}}".format(shard_n, 5) class Uploads: """ Uploads manages the post encoding steps, both CUDA -> cpu and the s3 upload. In order to avoid an expensive cuda sync event of the encode for every batch, instead "submit" the entirety of the post processing to a thread pool. Once the thread pool is full, we hand over the entirety of the thread pool to the python interpreter. This effectively allows multiple encoding batches to execute at once. At a 160 batch size, this uses <40 GB VRAM. TODO - probably would be better to wait until the thread pool is full and then execute just the least recent post processing? This could even be done without a thread pool or with a single thread, since it's executing one job at a time. Hmmm. The class must manage 1) the thread pool 2) the list of pending futures that have been submitted 3) a list of tar writers to upload results For the list of tar writers, we keep at most 5 open at a time. When we need to open an additional writer, we close the earliest opened one assuming that we have finished writing to it as the archives are read sequentially. This is an assumption but 5 is a safe buffer as we realistically will never be writing to more than 2 at a time for a reasonably sized thread pool. The list of tar writers is managed with a global lock because it opens a sub process and iirc Popen is not thread safe. Additionally each tar writer is managed with its own lock because writes are not thread safe and can corrupt the archive. """ def __init__(self, skip_upload, upload_to, num_writing_threads): self.open_lock = Lock() self.uploads = OrderedDict() self.skip_upload = skip_upload self.upload_to = upload_to self.futures = [] self.num_writing_threads = num_writing_threads self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=self.num_writing_threads) def __enter__(self): return self def __exit__(self, exc_type, exc_val, exc_tb): # Finish all pending encodings [x.result() for x in concurrent.futures.as_completed(self.futures)] self.executor.shutdown(wait=True) # Close all unclosed file writes for tar_file_name, tar_writer in self.uploads.items(): tar_writer["writer"].close() return False def submit( self, __key__, __url__, encoder_hidden_states, attention_mask_lengths, encoded_image_f8, encoded_image_f16, metadata, ): future = self.executor.submit( self._upload_thread_entrypoint, __key__, __url__, encoder_hidden_states, attention_mask_lengths, encoded_image_f8, encoded_image_f16, metadata, ) self.futures.append(future) # Give cuda some time to complete the encodings before moving to cpu and uploading if len(self.futures) == self.num_writing_threads: [x.result() for x in concurrent.futures.as_completed(self.futures)] self.futures = [] def _upload_thread_entrypoint( self, __key__, __url__, encoder_hidden_states, attention_mask_lengths, encoded_image_f8, encoded_image_f16, metadata, ): encoder_hidden_states = torch.unbind(encoder_hidden_states) encoded_image_f8 = torch.unbind(encoded_image_f8) encoded_image_f16 = torch.unbind(encoded_image_f16) for ( __key__, __url__, encoded_image_f8, encoded_image_f16, encoder_hidden_states, attention_mask_length, metadata, ) in zip( __key__, __url__, encoded_image_f8, encoded_image_f16, encoder_hidden_states, attention_mask_lengths, metadata, ): encoded_image_f8 = encoded_image_f8.clone().to("cpu") encoded_image_f16 = encoded_image_f16.clone().to("cpu") encoder_hidden_states = encoder_hidden_states.clone().to("cpu") if self.skip_upload: continue tar_file_name = get_tar_file_name(__url__) # It is not strictly clear to me if it is necessary to lock this whole block or # just part(s) of the kickout/create new writer. Just lock the whole function to be # safe. self.open_lock.acquire() if tar_file_name not in self.uploads: if len(self.uploads) == 5: # kick out the earliest one key = next(iter(self.uploads.keys())) self.uploads[key]["writer"].close() del self.uploads[key] upload_command = f"pipe:aws s3 cp - {self.upload_to}/{tar_file_name}" logger.warning(f"opening new writer for {upload_command}") self.uploads[tar_file_name] = { "writer": wds.TarWriter(upload_command), "lock": Lock(), } upload = self.uploads[tar_file_name] self.open_lock.release() metadata = dict(metadata) metadata["attention_mask_length"] = attention_mask_length sample = { "__key__": __key__, PAELLA_F8_VQVAE_EXT: encoded_image_f8, VQGAN_F16_VQVAE_EXT: encoded_image_f16, CLIP_EXT: encoder_hidden_states, "json": metadata, } # Not locking around the write will corrupt the tar file upload["lock"].acquire() upload["writer"].write(sample) upload["lock"].release() def distribute_shards(start_shard_all, end_shard_all, slurm_ntasks): total_shards = end_shard_all - start_shard_all + 1 shards_per_task = total_shards // slurm_ntasks shards_per_task = [shards_per_task] * slurm_ntasks # to distribute the remainder of tasks for non-evenly divisible number of shards left_over_shards = total_shards % slurm_ntasks for slurm_procid in range(left_over_shards): shards_per_task[slurm_procid] += 1 assert sum(shards_per_task) == total_shards distributed_shards = [] for slurm_procid in range(len(shards_per_task)): if slurm_procid == 0: start_shard = start_shard_all else: start_shard = distributed_shards[slurm_procid - 1][1] + 1 end_shard = start_shard + shards_per_task[slurm_procid] - 1 distributed_shards.append((start_shard, end_shard)) assert sum([end_shard - start_shard + 1 for start_shard, end_shard in distributed_shards]) == total_shards return distributed_shards def main(): parser = argparse.ArgumentParser() parser.add_argument( "--dataset", type=str, help="The dataset to pre-encode", choices=["laion_5", "laion_6", "coyo"], required=True, ) parser.add_argument( "--start_shard", type=int, help="The starting shard to pre-encode.", required=True, ) parser.add_argument( "--end_shard", type=int, help="The ending shard to pre-encode, inclusive. If not given, defaults to `--start_shard`.", required=False, ) parser.add_argument( "--slurm", action="store_true", help=( "If set, this process is running under a batch of slurm tasks." "`--start_shard` and `--end_shard` must be set for the entirety of shards over all slurm tasks." " The shards that will be encoded in each instance of the task will be determined via" " the env vars `$SLURM_NTASKS` and `$SLURM_PROCID`." ), ) parser.add_argument( "--batch_size", type=int, help="The batch size to encode at a time", required=False, default=160 ) parser.add_argument( "--resolution", type=int, help="The resolution to convert the image to.", required=False, default=256 ) parser.add_argument( "--skip_upload", action="store_true", help="Set to not actually upload results, helpful for only testing encoding.", ) parser.add_argument( "--num_writing_threads", type=int, required=False, default=40, ) args = parser.parse_args() if args.slurm and args.end_shard is None: raise ValueError("`--end_shard` must be set when `--slurm` is set") if args.end_shard is None: args.end_shard = args.start_shard if args.end_shard < args.start_shard: raise ValueError("`--end_shard` must be >= `--start_shard`") if args.batch_size < 1: raise ValueError("`--batch_size` must be >= 1") if args.resolution < 1: raise ValueError("`--resolution` must be >= 1") if args.dataset == "laion_5": args.dataset = LAION_AESTHETICS_V2_5_PLUS elif args.dataset == "laion_6": args.dataset = LAION_AESTHETICS_V2_6_PLUS elif args.dataset == "coyo": args.dataset = COYO else: assert False if args.dataset == LAION_AESTHETICS_V2_5_PLUS: upload_to = LAION_AESTHETICS_V2_5_PLUS_PRE_ENCODED elif args.dataset == LAION_AESTHETICS_V2_6_PLUS: upload_to = LAION_AESTHETICS_V2_6_PLUS_PRE_ENCODED elif args.dataset == COYO: upload_to = COYO_PRE_ENCODED else: assert False logger.warning("********************") logger.warning("Pre-encoding dataset") logger.warning(f"dataset: {args.dataset}") logger.warning(f"start_shard: {args.start_shard}") logger.warning(f"end_shard: {args.end_shard}") logger.warning(f"upload_to: {upload_to}") logger.warning(f"batch_size: {args.batch_size}") logger.warning("********************") if args.slurm: slurm_procid = int(os.environ["SLURM_PROCID"]) slurm_ntasks = int(os.environ["SLURM_NTASKS"]) distributed_shards = distribute_shards(args.start_shard, args.end_shard, slurm_ntasks) start_shard_task, end_shard_task = distributed_shards[slurm_procid] args.start_shard = start_shard_task args.end_shard = end_shard_task logger.warning("************") logger.warning("Running as slurm task") logger.warning(f"SLURM_NTASKS: {slurm_ntasks}") logger.warning(f"SLURM_PROCID: {slurm_procid}") logger.warning(f"start_shard: {start_shard_task}, end_shard: {end_shard_task}") logger.warning("************") logger.warning(f"all slurm processes") for slurm_proc_id_, (start_shard, end_shard) in enumerate(distributed_shards): logger.warning(f"slurm process: {slurm_proc_id_}, start_shard: {start_shard}, end_shard: {end_shard}") logger.warning("************") vae_f8 = PaellaVQModel.from_pretrained(PAELLA_F8_VQVAE) vae_f8.to("cuda") vae_f8.requires_grad_(False) vae_f16 = VQGANModel.from_pretrained(VQGAN_F16_VQVAE) vae_f16.to("cuda") vae_f16.requires_grad_(False) tokenizer = CLIPTokenizerFast.from_pretrained(CLIP) text_encoder = CLIPTextModel.from_pretrained(CLIP) text_encoder.to_bettertransformer() text_encoder.to("cuda") shard_range = "{" + format_shard_number(args.start_shard) + ".." + format_shard_number(args.end_shard) + "}" download_shards = f"pipe:aws s3 cp {args.dataset}/{shard_range}.tar -" logger.warning(f"downloading shards {download_shards}") src = ( wds.WebDataset( download_shards, ) .decode("pil", handler=wds.warn_and_continue) .rename(image="jpg;png;jpeg;webp", prompt="text;txt;caption", metadata="json") .map( lambda dict: { "__key__": dict["__key__"], "__url__": dict["__url__"], "image": dict["image"], "prompt": dict["prompt"], "metadata": dict["metadata"], } ) .to_tuple("__key__", "__url__", "image", "prompt", "metadata") .batched(args.batch_size) ) src = DataLoader( src, batch_size=None, shuffle=False, num_workers=0, ) with Uploads(args.skip_upload, upload_to, args.num_writing_threads) as uploads: for __key__, __url__, image, prompt, metadata in src: logger.warning(f"Encoding {len(__key__)} examples: {__key__[0]} to {__key__[-1]}.") encoded_prompts = tokenizer(prompt, padding="max_length", truncation=True, return_tensors="pt") attention_masks = encoded_prompts.attention_mask # attention masks are [1, 1, 1, 1, 0, ....., 0] so summing gives us the # index of last non-zero element. attention_mask_lengths = attention_masks.sum(-1) # Will be stored as a part of json metadata attention_mask_lengths = attention_mask_lengths.tolist() input_ids = encoded_prompts.input_ids.to("cuda") all_images = [] for image_ in image: # The following is minorly more efficient than the default # torchvision to_tensor and lets use move to cuda earlier :P mode = image_.mode height = image_.height width = image_.width if hasattr(image_, "getbands"): channels = len(image_.getbands()) else: channels = image_.channels if mode == "I": nptype = np.int32 elif mode == "I;16": nptype = np.int16 elif mode == "F": nptype = np.float32 else: nptype = np.uint8 image_ = np.array(image_, nptype) image_ = torch.from_numpy(image_) image_: torch.Tensor = image_.to("cuda") image_ = image_.view(height, width, channels) image_ = image_.permute((2, 0, 1)).contiguous() if mode != "1" and image_.dtype == torch.uint8: image_ = image_.to(dtype=torch.float32).div(255) image_ = TF.resize( image_, size=args.resolution, interpolation=InterpolationMode.BILINEAR, antialias=True ) image_ = TF.center_crop(image_, args.resolution) all_images.append(image_) image = torch.stack(all_images) encoder_hidden_states = text_encoder(input_ids)[0] with torch.cuda.amp.autocast(): encoded_image_f8 = vae_f8.get_code(image) with torch.cuda.amp.autocast(): encoded_image_f16 = vae_f16.get_code(image) uploads.submit( __key__, __url__, encoder_hidden_states, attention_mask_lengths, encoded_image_f8, encoded_image_f16, metadata, ) if __name__ == "__main__": main()