scripts/gen_sdxl_synthetic_dataset.py
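# Generate a synthetic text-to-image dataset: stream captions out of
# LAION-COCO, render each caption with the SDXL base + refiner pipelines,
# score every image against its caption with CLIP, and write the results to
# S3 as WebDataset tar shards.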
import logging
import time
import os
import argparse
import pyarrow.parquet as pq
import torch
import webdataset as wds
from diffusers import DiffusionPipeline
from huggingface_hub import HfFileSystem
from transformers import CLIPModel, CLIPProcessor
# Inference-only script: allow TF32 matmuls for speed and disable grad
# tracking globally.
torch.set_float32_matmul_precision("high")
torch.set_grad_enabled(False)
logger = logging.getLogger(__name__)
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--slurm", action="store_true")
    parser.add_argument("--n_shards_to_write", required=True, type=int)
    args = parser.parse_args()

    if args.slurm:
        # Each SLURM task reads captions from its own parquet shard.
        slurm_procid = int(os.environ["SLURM_PROCID"])
        # `1 +` because we already used caption shard 0 while doing testing
        caption_shard_n = 1 + slurm_procid
    else:
        caption_shard_n = 0
    device = "cuda"

    # CLIP is used only to score each generated image against its caption.
    clip = CLIPModel.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=torch.float16)
    clip.to(device)
    clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
    pipe = DiffusionPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16"
    )
    pipe.load_lora_weights(
        "stabilityai/stable-diffusion-xl-base-1.0",
        weight_name="sd_xl_offset_example-lora_1.0.safetensors",
    )
    pipe.to(device)
    pipe.fuse_lora(lora_scale=0.4)
    pipe.enable_xformers_memory_efficient_attention()
    pipe.vae.enable_slicing()

    # The refiner shares the base pipeline's second text encoder and VAE, so
    # only its UNet adds to GPU memory.
    refiner = DiffusionPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-refiner-1.0",
        text_encoder_2=pipe.text_encoder_2,
        vae=pipe.vae,
        torch_dtype=torch.float16,
        use_safetensors=True,
        variant="fp16",
    ).to(device)
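    # Base and refiner are run as an "ensemble of expert denoisers": the base
    # model covers the first portion of the noise schedule, hands its latents
    # to the refiner, and the refiner finishes the remaining steps.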
    captions = get_captions(caption_shard_n)
    # Group the caption stream into chunks of 500; each chunk becomes one
    # output tar shard.
    captions = take_up_to(captions, 500)
    for shard_n, captions_ in enumerate(captions):
        t0 = time.perf_counter()
        logger.warning(f"shard_n {shard_n}")

        # Stream the tar shard straight to S3 instead of staging it on disk.
        writer = wds.TarWriter(
            "pipe:aws s3 cp -"
            f" s3://muse-datasets/sdxl-synthetic-dataset/{caption_shard_n}/{format_shard_number(shard_n)}.tar"
        )

        key = 0

        for caption_batch_idx, captions__ in enumerate(split_list(captions_, 8)):
            logger.warning(f"caption_batch_idx {caption_batch_idx}")

            num_inference_steps = 35
            num_images_per_prompt = 4
            proportion_base_model = 0.8
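            # With denoising_end/denoising_start at 0.8, the base model runs
            # roughly the first 28 of the 35 steps and the refiner the last 7.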
            images = pipe(
                prompt=captions__,
                num_inference_steps=num_inference_steps,
                num_images_per_prompt=num_images_per_prompt,
                denoising_end=proportion_base_model,
                output_type="latent",
            ).images
            images = refiner(
                prompt=captions__,
                num_inference_steps=num_inference_steps,
                denoising_start=proportion_base_model,
                image=images,
                num_images_per_prompt=num_images_per_prompt,
            ).images
            for caption, images_ in zip(captions__, split_list(images, 4)):
                # TODO - can we avoid syncing images to cpu
                inputs = clip_processor(
                    text=caption,
                    images=images_,
                    return_tensors="pt",
                    padding="max_length",
                    max_length=77,
                    truncation=True,
                )
                inputs["pixel_values"] = inputs["pixel_values"].to(dtype=torch.float16, device=device)
                inputs["input_ids"] = inputs["input_ids"].to(device)
                inputs["attention_mask"] = inputs["attention_mask"].to(device)

                # `logits_per_image` is the scaled caption/image similarity for
                # each of the four images.
                clip_scores = clip(**inputs).logits_per_image.flatten().tolist()
                clip_scores = ",".join(str(x) for x in clip_scores)

                logger.warning(f"__key__ {key}")

                writer.write(
                    {
                        "__key__": format_shard_number(key),
                        "0.png": images_[0],
                        "1.png": images_[1],
                        "2.png": images_[2],
                        "3.png": images_[3],
                        "txt": caption,
                        "clip_scores.txt": clip_scores,
                    }
                )

                key += 1
        writer.close()

        logger.warning(f"shard_n {shard_n} {time.perf_counter() - t0}")

        # `>=` so that exactly `n_shards_to_write` shards are written.
        if shard_n + 1 >= args.n_shards_to_write:
            break

def format_shard_number(shard_n: int):
return "{:0>{}}".format(shard_n, 5)
# A shard has on the order of a million captions, so it's sufficient to just get captions from
# a single shard for a single job. Each job should pass in a unique caption_shard_n
def get_captions(caption_shard_n):
    fs = HfFileSystem()
    shards = fs.ls("datasets/laion/laion-coco", detail=False)

    # Pick the caption_shard_n-th parquet file in the dataset repo.
    shard_ctr = 0
    found_shard = None

    for shard in shards:
        if not shard.endswith(".parquet"):
            continue

        if shard_ctr == caption_shard_n:
            found_shard = shard
            break

        shard_ctr += 1

    assert found_shard is not None

    with fs.open(found_shard, "rb") as f:
        table = pq.read_table(f)

    for i in range(len(table[0])):
        caption = table[2][i]
        yield caption.as_py()
def take_up_to(iterator, n):
    iterator = iter(iterator)
    iterator_has_elements = True

    while iterator_has_elements:
        items = []

        for _ in range(n):
            try:
                items.append(next(iterator))
            except StopIteration:
                iterator_has_elements = False
                break

        # Skip the trailing empty chunk that occurs when the source length is
        # an exact multiple of n.
        if items:
            yield items
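# e.g. take_up_to(range(5), 2) lazily yields [0, 1], [2, 3], [4]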
def split_list(input_list, chunk_size):
    return [input_list[i : i + chunk_size] for i in range(0, len(input_list), chunk_size)]
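# e.g. split_list([1, 2, 3, 4, 5], 2) -> [[1, 2], [3, 4], [5]]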
if __name__ == "__main__":
    main()
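# Example usage (one caption shard per SLURM task):
#   srun python scripts/gen_sdxl_synthetic_dataset.py --slurm --n_shards_to_write 10
# or locally, against caption shard 0:
#   python scripts/gen_sdxl_synthetic_dataset.py --n_shards_to_write 1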