# scripts/pre_encode.py
import argparse
import os

import numpy as np
import torch
import torchvision.transforms.functional as TF
import webdataset as wds
from torch.utils.data import DataLoader
from torchvision.transforms import InterpolationMode
from transformers import CLIPTextModel, CLIPTokenizerFast

# Repo-local helpers and constants (logger, the dataset path constants,
# PaellaVQModel, VQGANModel, Uploads, distribute_shards,
# format_shard_number, CLIP, PAELLA_F8_VQVAE, VQGAN_F16_VQVAE) are
# imported from elsewhere in this repository.

def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--dataset",
type=str,
help="The dataset to pre-encode",
choices=["laion_5", "laion_6", "coyo"],
required=True,
)
parser.add_argument(
"--start_shard",
type=int,
help="The starting shard to pre-encode.",
required=True,
)
parser.add_argument(
"--end_shard",
type=int,
help="The ending shard to pre-encode, inclusive. If not given, defaults to `--start_shard`.",
required=False,
)
parser.add_argument(
"--slurm",
action="store_true",
help=(
"If set, this process is running under a batch of slurm tasks."
"`--start_shard` and `--end_shard` must be set for the entirety of shards over all slurm tasks."
" The shards that will be encoded in each instance of the task will be determined via"
" the env vars `$SLURM_NTASKS` and `$SLURM_PROCID`."
),
)
parser.add_argument(
"--batch_size", type=int, help="The batch size to encode at a time", required=False, default=160
)
parser.add_argument(
"--resolution", type=int, help="The resolution to convert the image to.", required=False, default=256
)
parser.add_argument(
"--skip_upload",
action="store_true",
help="Set to not actually upload results, helpful for only testing encoding.",
)
    parser.add_argument(
        "--num_writing_threads",
        type=int,
        help="The number of threads used to upload the encoded results.",
        required=False,
        default=40,
    )
args = parser.parse_args()
if args.slurm and args.end_shard is None:
raise ValueError("`--end_shard` must be set when `--slurm` is set")
if args.end_shard is None:
args.end_shard = args.start_shard
if args.end_shard < args.start_shard:
raise ValueError("`--end_shard` must be >= `--start_shard`")
if args.batch_size < 1:
raise ValueError("`--batch_size` must be >= 1")
if args.resolution < 1:
raise ValueError("`--resolution` must be >= 1")
if args.dataset == "laion_5":
args.dataset = LAION_AESTHETICS_V2_5_PLUS
elif args.dataset == "laion_6":
args.dataset = LAION_AESTHETICS_V2_6_PLUS
elif args.dataset == "coyo":
args.dataset = COYO
else:
assert False
if args.dataset == LAION_AESTHETICS_V2_5_PLUS:
upload_to = LAION_AESTHETICS_V2_5_PLUS_PRE_ENCODED
elif args.dataset == LAION_AESTHETICS_V2_6_PLUS:
upload_to = LAION_AESTHETICS_V2_6_PLUS_PRE_ENCODED
elif args.dataset == COYO:
upload_to = COYO_PRE_ENCODED
else:
assert False
logger.warning("********************")
logger.warning("Pre-encoding dataset")
logger.warning(f"dataset: {args.dataset}")
logger.warning(f"start_shard: {args.start_shard}")
logger.warning(f"end_shard: {args.end_shard}")
logger.warning(f"upload_to: {upload_to}")
logger.warning(f"batch_size: {args.batch_size}")
logger.warning("********************")
if args.slurm:
slurm_procid = int(os.environ["SLURM_PROCID"])
slurm_ntasks = int(os.environ["SLURM_NTASKS"])
distributed_shards = distribute_shards(args.start_shard, args.end_shard, slurm_ntasks)
start_shard_task, end_shard_task = distributed_shards[slurm_procid]
args.start_shard = start_shard_task
args.end_shard = end_shard_task
logger.warning("************")
logger.warning("Running as slurm task")
logger.warning(f"SLURM_NTASKS: {slurm_ntasks}")
logger.warning(f"SLURM_PROCID: {slurm_procid}")
logger.warning(f"start_shard: {start_shard_task}, end_shard: {end_shard_task}")
logger.warning("************")
logger.warning(f"all slurm processes")
for slurm_proc_id_, (start_shard, end_shard) in enumerate(distributed_shards):
logger.warning(f"slurm process: {slurm_proc_id_}, start_shard: {start_shard}, end_shard: {end_shard}")
logger.warning("************")
vae_f8 = PaellaVQModel.from_pretrained(PAELLA_F8_VQVAE)
vae_f8.to("cuda")
vae_f8.requires_grad_(False)
vae_f16 = VQGANModel.from_pretrained(VQGAN_F16_VQVAE)
vae_f16.to("cuda")
vae_f16.requires_grad_(False)
    tokenizer = CLIPTokenizerFast.from_pretrained(CLIP)
    text_encoder = CLIPTextModel.from_pretrained(CLIP)
    # `to_bettertransformer()` returns the converted model rather than
    # converting in place, so the result must be reassigned.
    text_encoder = text_encoder.to_bettertransformer()
    text_encoder.to("cuda")
    # Inference only, matching the VQ models above.
    text_encoder.requires_grad_(False)
shard_range = "{" + format_shard_number(args.start_shard) + ".." + format_shard_number(args.end_shard) + "}"
download_shards = f"pipe:aws s3 cp {args.dataset}/{shard_range}.tar -"
logger.warning(f"downloading shards {download_shards}")
src = (
wds.WebDataset(
download_shards,
)
.decode("pil", handler=wds.warn_and_continue)
.rename(image="jpg;png;jpeg;webp", prompt="text;txt;caption", metadata="json")
.map(
            lambda sample: {
                "__key__": sample["__key__"],
                "__url__": sample["__url__"],
                "image": sample["image"],
                "prompt": sample["prompt"],
                "metadata": sample["metadata"],
            }
)
.to_tuple("__key__", "__url__", "image", "prompt", "metadata")
.batched(args.batch_size)
)
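    # Batching already happens in the webdataset pipeline, so the DataLoader
    # (batch_size=None) is just an iterator over ready-made batches.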
src = DataLoader(
src,
batch_size=None,
shuffle=False,
num_workers=0,
)
with Uploads(args.skip_upload, upload_to, args.num_writing_threads) as uploads:
for __key__, __url__, image, prompt, metadata in src:
logger.warning(f"Encoding {len(__key__)} examples: {__key__[0]} to {__key__[-1]}.")
encoded_prompts = tokenizer(prompt, padding="max_length", truncation=True, return_tensors="pt")
attention_masks = encoded_prompts.attention_mask
            # Attention masks look like [1, 1, 1, 1, 0, ..., 0], so summing
            # gives the number of attended tokens (the prompt length).
attention_mask_lengths = attention_masks.sum(-1)
# Will be stored as a part of json metadata
attention_mask_lengths = attention_mask_lengths.tolist()
input_ids = encoded_prompts.input_ids.to("cuda")
all_images = []
for image_ in image:
                # The following is slightly more efficient than the default
                # torchvision `to_tensor` and lets us move to cuda earlier :P
mode = image_.mode
height = image_.height
width = image_.width
                # `.decode("pil")` always yields PIL images, which provide
                # `getbands()`, so the channel count can be read directly.
                channels = len(image_.getbands())
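                # Match torchvision `to_tensor`'s dtype handling for the
                # less common PIL modes; everything else decodes as uint8.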
if mode == "I":
nptype = np.int32
elif mode == "I;16":
nptype = np.int16
elif mode == "F":
nptype = np.float32
else:
nptype = np.uint8
image_ = np.array(image_, nptype)
image_ = torch.from_numpy(image_)
image_: torch.Tensor = image_.to("cuda")
image_ = image_.view(height, width, channels)
image_ = image_.permute((2, 0, 1)).contiguous()
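                # Scale uint8 images to [0, 1]; bilevel ("1") images are
                # left untouched.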
if mode != "1" and image_.dtype == torch.uint8:
image_ = image_.to(dtype=torch.float32).div(255)
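                # Resize the shorter side to the target resolution, then
                # center-crop to a square.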
image_ = TF.resize(
image_, size=args.resolution, interpolation=InterpolationMode.BILINEAR, antialias=True
)
image_ = TF.center_crop(image_, args.resolution)
all_images.append(image_)
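            # Note: stacking assumes every image in the batch decodes to the
            # same channel count (e.g. all RGB).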
image = torch.stack(all_images)
encoder_hidden_states = text_encoder(input_ids)[0]
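            # Encode the image batch with both VQ models under autocast for
            # mixed-precision inference.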
            with torch.cuda.amp.autocast():
                encoded_image_f8 = vae_f8.get_code(image)
                encoded_image_f16 = vae_f16.get_code(image)
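            # Queue the encoded batch for the background writer threads
            # (expected to be a dry run when --skip_upload is set).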
uploads.submit(
__key__,
__url__,
encoder_hidden_states,
attention_mask_lengths,
encoded_image_f8,
encoded_image_f16,
metadata,
)