in scripts/gen_sdxl_synthetic_dataset.py [0:0]
import argparse
import os
import time

import torch
import webdataset as wds
from diffusers import DiffusionPipeline
from transformers import CLIPModel, CLIPProcessor

# `logger`, `get_captions`, `take_up_to`, `split_list`, and `format_shard_number`
# are assumed to be defined elsewhere in this script.


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--slurm", action="store_true")
    parser.add_argument("--n_shards_to_write", required=True, type=int)
    args = parser.parse_args()
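
    # Under SLURM, each task draws from its own caption shard so that no two
    # tasks generate from the same captions.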
    if args.slurm:
        slurm_procid = int(os.environ["SLURM_PROCID"])
        # `1 +` because we already used caption shard 0 while doing testing
        caption_shard_n = 1 + slurm_procid
    else:
        caption_shard_n = 0

    device = "cuda"
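
    # CLIP (ViT-L/14) scores how well each generated image matches its caption;
    # the scores are written out alongside the images.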
    clip = CLIPModel.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=torch.float16)
    clip.to(device)
    clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
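
    # SDXL base pipeline with the official offset-noise example LoRA fused in
    # at reduced strength.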
    pipe = DiffusionPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16"
    )
    pipe.load_lora_weights(
        "stabilityai/stable-diffusion-xl-base-1.0",
        weight_name="sd_xl_offset_example-lora_1.0.safetensors",
    )
    pipe.to(device)
    pipe.fuse_lora(lora_scale=0.4)
    pipe.enable_xformers_memory_efficient_attention()
    pipe.vae.enable_slicing()
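
    # The refiner reuses the base pipeline's second text encoder and VAE, so
    # the two pipelines share those weights instead of loading duplicates.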
    refiner = DiffusionPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-refiner-1.0",
        text_encoder_2=pipe.text_encoder_2,
        vae=pipe.vae,
        torch_dtype=torch.float16,
        use_safetensors=True,
        variant="fp16",
    ).to(device)
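
    # `take_up_to` presumably chunks the caption iterator into groups of at
    # most 500; each group becomes one output shard.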
    captions = get_captions(caption_shard_n)
    captions = take_up_to(captions, 500)

    for shard_n, captions_ in enumerate(captions):
        t0 = time.perf_counter()
        logger.warning(f"shard_n {shard_n}")
        writer = wds.TarWriter(
            "pipe:aws s3 cp -"
            f" s3://muse-datasets/sdxl-synthetic-dataset/{caption_shard_n}/{format_shard_number(shard_n)}.tar"
        )

        key = 0
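
        # Generate in batches of 8 captions, with 4 images per caption.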
        for caption_batch_idx, captions__ in enumerate(split_list(captions_, 8)):
            logger.warning(f"caption_batch_idx {caption_batch_idx}")

            num_inference_steps = 35
            num_images_per_prompt = 4
            proportion_base_model = 0.8
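
            # Base/refiner ensemble: the base model denoises the first 80% of
            # the schedule and returns latents, which the refiner finishes.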
            images = pipe(
                prompt=captions__,
                num_inference_steps=num_inference_steps,
                num_images_per_prompt=num_images_per_prompt,
                denoising_end=proportion_base_model,
                output_type="latent",
            ).images
            images = refiner(
                prompt=captions__,
                num_inference_steps=num_inference_steps,
                denoising_start=proportion_base_model,
                image=images,
                num_images_per_prompt=num_images_per_prompt,
            ).images

            for caption, images_ in zip(captions__, split_list(images, 4)):
                # TODO - can we avoid syncing images to cpu
                inputs = clip_processor(
                    text=caption,
                    images=images_,
                    return_tensors="pt",
                    padding="max_length",
                    max_length=77,
                    truncation=True,
                )
                inputs["pixel_values"] = inputs["pixel_values"].to(dtype=torch.float16, device=device)
                inputs["input_ids"] = inputs["input_ids"].to(device)
                inputs["attention_mask"] = inputs["attention_mask"].to(device)
                clip_scores = clip(**inputs).logits_per_image.flatten().tolist()
                clip_scores = ",".join(str(score) for score in clip_scores)
logger.warning(f"__key__ {key}")
writer.write(
{
"__key__": format_shard_number(key),
"0.png": images_[0],
"1.png": images_[1],
"2.png": images_[2],
"3.png": images_[3],
"txt": caption,
"clip_scores.txt": clip_scores,
}
)
key += 1
writer.close()
logger.warning(f"shard_n {shard_n} {time.perf_counter() - t0}")

        # `>=` so that exactly `n_shards_to_write` shards are written.
        if shard_n + 1 >= args.n_shards_to_write:
            break