def main()

in scripts/pre_encode.py [0:0]


def _parse_pre_encode_args():
    """Build the command-line parser for the pre-encoding script and parse it.

    Returns:
        The ``argparse.Namespace`` holding the raw options. Cross-option
        validation (shard ordering, positive sizes, ...) is left to ``main``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--dataset",
        type=str,
        help="The dataset to pre-encode",
        choices=["laion_5", "laion_6", "coyo"],
        required=True,
    )
    parser.add_argument(
        "--start_shard",
        type=int,
        help="The starting shard to pre-encode.",
        required=True,
    )
    parser.add_argument(
        "--end_shard",
        type=int,
        help="The ending shard to pre-encode, inclusive. If not given, defaults to `--start_shard`.",
        required=False,
    )
    parser.add_argument(
        "--slurm",
        action="store_true",
        help=(
            "If set, this process is running under a batch of slurm tasks."
            " `--start_shard` and `--end_shard` must be set for the entirety of shards over all slurm tasks."
            " The shards that will be encoded in each instance of the task will be determined via"
            " the env vars `$SLURM_NTASKS` and `$SLURM_PROCID`."
        ),
    )
    parser.add_argument(
        "--batch_size", type=int, help="The batch size to encode at a time", required=False, default=160
    )
    parser.add_argument(
        "--resolution", type=int, help="The resolution to convert the image to.", required=False, default=256
    )
    parser.add_argument(
        "--skip_upload",
        action="store_true",
        help="Set to not actually upload results, helpful for only testing encoding.",
    )
    parser.add_argument(
        "--num_writing_threads",
        type=int,
        help="The number of writing threads handed to the uploader.",
        required=False,
        default=40,
    )

    return parser.parse_args()


def _pil_image_to_cuda_tensor(pil_image, resolution):
    """Convert one decoded PIL image into a resized, center-cropped CUDA tensor.

    This hand-rolled conversion is slightly cheaper than torchvision's default
    ``to_tensor`` and lets the data move to the GPU before any resizing happens.

    Args:
        pil_image: A PIL image (anything exposing ``mode``, ``height``,
            ``width`` and ``getbands()`` or ``channels``).
        resolution: Value passed as ``size`` to ``TF.resize`` and as the output
            size to ``TF.center_crop``.

    Returns:
        A channels-first tensor on the ``"cuda"`` device. ``uint8`` inputs
        (other than mode ``"1"``) are converted to ``float32`` and divided by 255.
    """
    mode = pil_image.mode
    height = pil_image.height
    width = pil_image.width

    if hasattr(pil_image, "getbands"):
        channels = len(pil_image.getbands())
    else:
        channels = pil_image.channels

    # Wider/float PIL modes need a matching numpy dtype; everything else is 8-bit.
    nptype = {"I": np.int32, "I;16": np.int16, "F": np.float32}.get(mode, np.uint8)

    tensor = torch.from_numpy(np.array(pil_image, nptype)).to("cuda")

    # (height, width, channels) -> (channels, height, width)
    tensor = tensor.view(height, width, channels)
    tensor = tensor.permute((2, 0, 1)).contiguous()

    if mode != "1" and tensor.dtype == torch.uint8:
        tensor = tensor.to(dtype=torch.float32).div(255)

    tensor = TF.resize(tensor, size=resolution, interpolation=InterpolationMode.BILINEAR, antialias=True)

    return TF.center_crop(tensor, resolution)


def main():
    """Pre-encode a range of dataset shards and hand the results to the uploader.

    Steps:
      1. Parse and validate the command-line options.
      2. Resolve the dataset name to its source location and upload destination.
      3. When ``--slurm`` is set, narrow the shard range to this task's share
         using ``$SLURM_PROCID`` and ``$SLURM_NTASKS``.
      4. Load both VQ models, the CLIP tokenizer and the CLIP text encoder on CUDA.
      5. Stream the shards, encode every batch (text hidden states plus the
         f8 and f16 image codes) and submit it to ``Uploads``.

    Raises:
        ValueError: If ``--slurm`` is given without ``--end_shard``, if
            ``--end_shard`` is smaller than ``--start_shard``, if
            ``--batch_size`` or ``--resolution`` is below 1, or if the dataset
            name is unknown.
        KeyError: If ``--slurm`` is set but the slurm environment variables are
            missing.
    """
    args = _parse_pre_encode_args()

    if args.slurm and args.end_shard is None:
        raise ValueError("`--end_shard` must be set when `--slurm` is set")

    if args.end_shard is None:
        args.end_shard = args.start_shard

    if args.end_shard < args.start_shard:
        raise ValueError("`--end_shard` must be >= `--start_shard`")

    if args.batch_size < 1:
        raise ValueError("`--batch_size` must be >= 1")

    if args.resolution < 1:
        raise ValueError("`--resolution` must be >= 1")

    # CLI name -> (source location, upload destination). A single table keeps
    # the two halves in sync and does not rely on `assert`, which is stripped
    # when Python runs with -O.
    dataset_locations = {
        "laion_5": (LAION_AESTHETICS_V2_5_PLUS, LAION_AESTHETICS_V2_5_PLUS_PRE_ENCODED),
        "laion_6": (LAION_AESTHETICS_V2_6_PLUS, LAION_AESTHETICS_V2_6_PLUS_PRE_ENCODED),
        "coyo": (COYO, COYO_PRE_ENCODED),
    }
    if args.dataset not in dataset_locations:
        raise ValueError(f"unknown dataset: {args.dataset}")
    args.dataset, upload_to = dataset_locations[args.dataset]

    logger.warning("********************")
    logger.warning("Pre-encoding dataset")
    logger.warning(f"dataset: {args.dataset}")
    logger.warning(f"start_shard: {args.start_shard}")
    logger.warning(f"end_shard: {args.end_shard}")
    logger.warning(f"upload_to: {upload_to}")
    logger.warning(f"batch_size: {args.batch_size}")
    logger.warning("********************")

    if args.slurm:
        slurm_procid = int(os.environ["SLURM_PROCID"])
        slurm_ntasks = int(os.environ["SLURM_NTASKS"])

        # The CLI range covers all tasks; pick out the sub-range for this one.
        distributed_shards = distribute_shards(args.start_shard, args.end_shard, slurm_ntasks)

        start_shard_task, end_shard_task = distributed_shards[slurm_procid]

        args.start_shard = start_shard_task
        args.end_shard = end_shard_task

        logger.warning("************")
        logger.warning("Running as slurm task")
        logger.warning(f"SLURM_NTASKS: {slurm_ntasks}")
        logger.warning(f"SLURM_PROCID: {slurm_procid}")
        logger.warning(f"start_shard: {start_shard_task}, end_shard: {end_shard_task}")
        logger.warning("************")
        logger.warning("all slurm processes")
        for slurm_proc_id_, (start_shard, end_shard) in enumerate(distributed_shards):
            logger.warning(f"slurm process: {slurm_proc_id_}, start_shard: {start_shard}, end_shard: {end_shard}")
        logger.warning("************")

    vae_f8 = PaellaVQModel.from_pretrained(PAELLA_F8_VQVAE)
    vae_f8.to("cuda")
    vae_f8.requires_grad_(False)

    vae_f16 = VQGANModel.from_pretrained(VQGAN_F16_VQVAE)
    vae_f16.to("cuda")
    vae_f16.requires_grad_(False)

    tokenizer = CLIPTokenizerFast.from_pretrained(CLIP)
    text_encoder = CLIPTextModel.from_pretrained(CLIP)
    text_encoder.to_bettertransformer()
    text_encoder.to("cuda")

    shard_range = "{" + format_shard_number(args.start_shard) + ".." + format_shard_number(args.end_shard) + "}"
    download_shards = f"pipe:aws s3 cp {args.dataset}/{shard_range}.tar -"

    logger.warning(f"downloading shards {download_shards}")

    src = (
        wds.WebDataset(
            download_shards,
        )
        .decode("pil", handler=wds.warn_and_continue)
        .rename(image="jpg;png;jpeg;webp", prompt="text;txt;caption", metadata="json")
        .map(
            # Keep only the fields that are consumed below.
            lambda sample: {
                "__key__": sample["__key__"],
                "__url__": sample["__url__"],
                "image": sample["image"],
                "prompt": sample["prompt"],
                "metadata": sample["metadata"],
            }
        )
        .to_tuple("__key__", "__url__", "image", "prompt", "metadata")
        .batched(args.batch_size)
    )
    # Batching is already done by the webdataset pipeline, hence batch_size=None.
    src = DataLoader(
        src,
        batch_size=None,
        shuffle=False,
        num_workers=0,
    )

    with Uploads(args.skip_upload, upload_to, args.num_writing_threads) as uploads:
        for keys, urls, images, prompts, metadata in src:
            logger.warning(f"Encoding {len(keys)} examples: {keys[0]} to {keys[-1]}.")

            encoded_prompts = tokenizer(prompts, padding="max_length", truncation=True, return_tensors="pt")

            # Attention masks look like [1, 1, 1, 1, 0, ....., 0], so summing
            # gives the number of non-padding tokens of each prompt. Converted
            # to a plain list because it is stored as part of the json metadata.
            attention_mask_lengths = encoded_prompts.attention_mask.sum(-1).tolist()

            input_ids = encoded_prompts.input_ids.to("cuda")

            image_batch = torch.stack(
                [_pil_image_to_cuda_tensor(pil_image, args.resolution) for pil_image in images]
            )

            # Pure inference: the text encoder's parameters are never frozen
            # above, so without no_grad every batch could build and retain an
            # autograd graph on the tensors handed to the uploader.
            with torch.no_grad():
                # NOTE(review): the text encoder is deliberately left outside
                # autocast, as in the original code; presumably the
                # BetterTransformer conversion does not support it - confirm.
                encoder_hidden_states = text_encoder(input_ids)[0]

                with torch.cuda.amp.autocast():
                    encoded_image_f8 = vae_f8.get_code(image_batch)
                    encoded_image_f16 = vae_f16.get_code(image_batch)

            uploads.submit(
                keys,
                urls,
                encoder_hidden_states,
                attention_mask_lengths,
                encoded_image_f8,
                encoded_image_f16,
                metadata,
            )