scripts/log_generations_wandb.py:
import json
from argparse import ArgumentParser
from itertools import islice

import torch
import wandb

from muse import PipelineMuse


def chunk(it, size):
    # Split an iterable into tuples of `size` items; the final chunk may be shorter.
    it = iter(it)
    return iter(lambda: tuple(islice(it, size)), ())


def generate_and_log(args):
    """Sample images for every prompt in the prompts file and log them to a wandb table."""
    run_name = f"{args.transformer} samples at checkpoint {args.checkpoint}"
    wandb.init(
        project=args.project,
        entity=args.entity,
        name=run_name,
        notes=(
            f"Samples from {args.run_id} at checkpoint {args.checkpoint} with timesteps={args.timesteps},"
            f" guidance_scale={args.guidance_scale}, temperature={args.temperature}"
        ),
    )

    # Load the full text-to-image pipeline (text encoder, VQGAN VAE, and transformer) and move it to the target device.
    pipe = PipelineMuse.from_pretrained(
        text_encoder_path=args.text_encoder,
        vae_path=args.vae,
        transformer_path=args.transformer,
    ).to(device=args.device)
    # xFormers memory-efficient attention reduces GPU memory use during sampling.
    pipe.transformer.enable_xformers_memory_efficient_attention()

    # Read the prompts file into a list, dropping trailing newlines and blank lines.
    with open(args.prompts_file_path, "r") as f:
        prompts = [line.strip() for line in f if line.strip()]

    # Split the prompts into batches of size args.batch_size.
    prompts = list(chunk(prompts, args.batch_size))

    # Generate images batch by batch and collect them in a wandb table.
    table = wandb.Table(columns=["prompt"] + [f"image {i}" for i in range(args.num_generations)])
    for batch in prompts:
        images = pipe(
            batch,
            timesteps=args.timesteps,
            guidance_scale=args.guidance_scale,
            temperature=args.temperature,
            num_images_per_prompt=args.num_generations,
            use_maskgit_generate=True,
            use_fp16=True,
        )

        # The pipeline returns a flat list with num_generations images per prompt; regroup them
        # and add one table row per prompt: [prompt, wandb.Image 1, wandb.Image 2, ...].
        images = list(chunk(images, args.num_generations))
        for prompt, gen_images in zip(batch, images):
            row = [prompt]
            for image in gen_images:
                row.append(wandb.Image(image))
            table.add_data(*row)

    # Log the completed table once, after all batches have been processed.
    wandb.log({"samples": table})


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--project", type=str, default="muse")
    parser.add_argument("--entity", type=str, default="psuraj")
    parser.add_argument("--run_id", type=str, required=True)
    parser.add_argument("--timesteps", type=int, default=12)
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--guidance_scale", type=float, default=8)
    parser.add_argument("--num_generations", type=int, default=8)
    parser.add_argument("--checkpoint", type=str, required=True)
    parser.add_argument("--text_encoder", type=str, default="google/t5-v1_1-large")
    parser.add_argument("--vae", type=str, default="openMUSE/maskgit-vqgan-imagenet-f16-256")
    parser.add_argument("--transformer", type=str, required=True)
    parser.add_argument("--prompts_file_path", type=str, required=True)
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--batch_size", type=int, default=64)

    args = parser.parse_args()
    generate_and_log(args)
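
# Example invocation (the run id, checkpoint, transformer path, and prompts file below are placeholders):
#   python scripts/log_generations_wandb.py \
#       --run_id <wandb-run-id> \
#       --checkpoint <checkpoint-step> \
#       --transformer <transformer-path-or-hub-id> \
#       --prompts_file_path prompts.txt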