scripts/log_inpainting_images.py (80 lines of code) (raw):
import copy
import json
import os
from argparse import ArgumentParser
from itertools import islice
import numpy as np
import torch
import wandb
from PIL import Image
from muse import PipelineMuseInpainting
def chunk(it, size):
    """Yield successive tuples of at most ``size`` items drawn from ``it``.

    The final tuple may be shorter than ``size``; iteration stops as soon as
    the underlying iterator has nothing left to give.
    """
    # Obtain the iterator eagerly so a non-iterable argument fails at call time.
    source = iter(it)

    def _batches():
        while True:
            batch = tuple(islice(source, size))
            if not batch:
                # An empty tuple means the source is exhausted.
                return
            yield batch

    return _batches()
def generate_and_log(args):
    """Run Muse inpainting on one image and save the results as JPEG files.

    Steps performed:
      1. Load ``PipelineMuseInpainting`` from ``args.model_name_or_path`` and
         move it to ``args.device``.
      2. Build the conditioning input: a class-id tensor when
         ``args.is_class_conditioned`` is set, otherwise the text prompt.
      3. Build a flattened boolean mask over a square grid of side
         ``args.image_size // args.vae_scaling_factor``.
      4. Save a preview of the resized input with the masked rectangle
         blacked out.
      5. Call the pipeline and save every image it returns.

    Args:
        args: Parsed command-line namespace. Attributes read here:
            ``output_dir``, ``vae_scaling_factor``, ``model_name_or_path``,
            ``is_class_conditioned``, ``device``, ``imagenet_class_id``,
            ``text``, ``image_size``, ``mask_start_x``, ``mask_end_x``,
            ``mask_start_y``, ``mask_end_y``, ``input_image``, ``timesteps``,
            ``guidance_scale``, ``temperature``, ``not_maskgit_generate`` and
            ``num_generations``.

    Files written under ``args.output_dir``:
        ``segmented.jpg``: the resized input with the masked rectangle zeroed.
        ``output_{class_id}_{i}.jpg``: results when class conditioned.
        ``output_{i}.jpg``: results when text conditioned.
    """
    os.makedirs(args.output_dir, exist_ok=True)
    vae_scaling_factor = args.vae_scaling_factor

    pipe = PipelineMuseInpainting.from_pretrained(
        model_name_or_path=args.model_name_or_path,
        is_class_conditioned=args.is_class_conditioned,
    ).to(device=args.device)
    pipe.transformer.enable_xformers_memory_efficient_attention()

    if args.is_class_conditioned:
        imagenet_class_ids = [args.imagenet_class_id]
        class_ids = torch.tensor(imagenet_class_ids).to(device=args.device, dtype=torch.long)
        inputs = {"class_ids": class_ids}
    else:
        inputs = {"text": args.text}

    # The mask lives on the down-scaled grid (presumably one cell per VAE
    # token -- confirm against the pipeline). It is built directly as a
    # boolean array instead of going through float64 first.
    # NOTE: the "x" bounds index the first (row) axis and the "y" bounds the
    # second (column) axis, both here and in the pixel-space preview below.
    grid_size = args.image_size // vae_scaling_factor
    mask = np.zeros((grid_size, grid_size), dtype=bool)
    mask[args.mask_start_x : args.mask_end_x, args.mask_start_y : args.mask_end_y] = True
    mask = torch.tensor(mask.reshape(-1), dtype=torch.bool).to(args.device)

    # Convert to RGB so the preview can always be written as JPEG; inputs
    # with an alpha channel or a palette (common for PNG files) would
    # otherwise make the JPEG save below raise. The context manager closes
    # the source file once the pixels have been copied out.
    with Image.open(args.input_image) as source:
        image = source.convert("RGB").resize((args.image_size, args.image_size))

    # np.array() already returns a fresh array, so zeroing the rectangle
    # leaves `image` itself untouched without an extra deep copy.
    masked_pixels = np.array(image)
    masked_pixels[
        args.mask_start_x * vae_scaling_factor : args.mask_end_x * vae_scaling_factor,
        args.mask_start_y * vae_scaling_factor : args.mask_end_y * vae_scaling_factor,
    ] = 0
    Image.fromarray(masked_pixels).save(os.path.join(args.output_dir, "segmented.jpg"))

    images = pipe(
        image=image,
        mask=mask,
        **inputs,
        timesteps=args.timesteps,
        guidance_scale=args.guidance_scale,
        temperature=args.temperature,
        use_maskgit_generate=not args.not_maskgit_generate,
        num_images_per_prompt=args.num_generations,
        image_size=args.image_size,
    )

    if args.is_class_conditioned:
        # Results are split into consecutive runs of `num_generations`, one
        # run per requested class id.
        per_class = chunk(images, args.num_generations)
        for class_id, class_images in zip(imagenet_class_ids, per_class):
            for i, generated in enumerate(class_images):
                generated.save(os.path.join(args.output_dir, f"output_{class_id}_{i}.jpg"))
    else:
        for i, generated in enumerate(images):
            generated.save(os.path.join(args.output_dir, f"output_{i}.jpg"))
if __name__ == "__main__":
    # Command-line entry point: declare the options, then hand the parsed
    # namespace straight to generate_and_log.
    cli = ArgumentParser()

    # Conditioning mode and sampling behaviour.
    cli.add_argument("--is_class_conditioned", action="store_true")
    cli.add_argument("--timesteps", type=int, default=18)
    cli.add_argument("--temperature", type=float, default=1.0)
    cli.add_argument("--guidance_scale", type=float, default=2.0)
    cli.add_argument("--not_maskgit_generate", action="store_true")
    cli.add_argument("--num_generations", type=int, default=8)

    # Model checkpoint and device.
    cli.add_argument("--model_name_or_path", type=str, default="openMUSE/muse-laiona6-uvit-clip-220k")
    cli.add_argument("--device", type=str, default="cuda")

    # Conditioning values: the class id is used with --is_class_conditioned,
    # the text prompt otherwise.
    cli.add_argument("--imagenet_class_id", type=int, default=248)
    cli.add_argument("--text", type=str, default="a picture of a dog")

    # Input image and the size it is resized to.
    cli.add_argument("--input_image", type=str, required=True)
    cli.add_argument("--image_size", type=int, default=256)

    # Mask rectangle, in units of the down-scaled grid
    # (image_size // vae_scaling_factor cells per side).
    cli.add_argument("--mask_start_x", type=int, default=4)
    cli.add_argument("--mask_start_y", type=int, default=4)
    cli.add_argument("--mask_end_x", type=int, default=12)
    cli.add_argument("--mask_end_y", type=int, default=12)
    cli.add_argument("--vae_scaling_factor", type=int, default=16)

    # Directory that receives the generated files.
    cli.add_argument("--output_dir", type=str, default="generated")

    generate_and_log(cli.parse_args())