def generate_and_log()

in scripts/log_inpainting_images.py [0:0]


def generate_and_log(args):
    """Inpaint a rectangular region of ``args.input_image`` and save the results.

    Loads ``PipelineMuseInpainting`` from ``args.model_name_or_path``, masks the
    rectangle ``[mask_start_x:mask_end_x, mask_start_y:mask_end_y]`` (given in
    latent-grid units, i.e. pixel coordinates divided by
    ``args.vae_scaling_factor``), runs the pipeline, and writes into
    ``args.output_dir`` (created if missing):

    * ``segmented.jpg`` -- the resized input with the masked region zeroed;
    * ``output_{class_id}_{i}.jpg`` (class-conditioned) or ``output_{i}.jpg``
      (text-conditioned) -- one file per generated image.

    Note: despite the ``_x``/``_y`` names, the ``*_x`` bounds slice the first
    array axis (rows, i.e. vertical) and the ``*_y`` bounds slice the second
    axis (columns), for both the latent mask and the pixel preview.

    Args:
        args: parsed command-line namespace. Fields read here: output_dir,
            vae_scaling_factor, model_name_or_path, is_class_conditioned,
            device, imagenet_class_id, text, image_size, mask_start_x,
            mask_end_x, mask_start_y, mask_end_y, input_image, timesteps,
            guidance_scale, temperature, not_maskgit_generate, num_generations.
    """
    os.makedirs(args.output_dir, exist_ok=True)
    vae_scaling_factor = args.vae_scaling_factor
    pipe = PipelineMuseInpainting.from_pretrained(
        model_name_or_path=args.model_name_or_path,
        is_class_conditioned=args.is_class_conditioned,
    ).to(device=args.device)
    pipe.transformer.enable_xformers_memory_efficient_attention()

    # Conditioning: a single ImageNet class id, or a text prompt.
    if args.is_class_conditioned:
        imagenet_class_ids = [args.imagenet_class_id]
        class_ids = torch.tensor(imagenet_class_ids).to(device=args.device, dtype=torch.long)
        inputs = {"class_ids": class_ids}
    else:
        inputs = {"text": args.text}

    # Boolean mask over the latent grid (True = region to regenerate),
    # flattened to 1-D before being passed to the pipeline.
    latent_size = args.image_size // vae_scaling_factor
    latent_mask = np.zeros((latent_size, latent_size), dtype=bool)
    latent_mask[args.mask_start_x : args.mask_end_x, args.mask_start_y : args.mask_end_y] = True
    mask = torch.from_numpy(latent_mask.reshape(-1)).to(args.device)

    # Convert to RGB: Pillow cannot save RGBA or palette ("P") images as JPEG,
    # so the preview save below would fail for such inputs (the pipeline also
    # presumably expects 3-channel input). The ``with`` block closes the file
    # handle that Image.open keeps open; convert() has already loaded the data.
    with Image.open(args.input_image) as source:
        image = source.convert("RGB").resize((args.image_size, args.image_size))

    # Preview of the input with the masked region zeroed; latent-grid bounds are
    # scaled back to pixel coordinates. np.array returns a fresh, writable copy.
    masked_pixels = np.array(image)
    masked_pixels[
        args.mask_start_x * vae_scaling_factor : args.mask_end_x * vae_scaling_factor,
        args.mask_start_y * vae_scaling_factor : args.mask_end_y * vae_scaling_factor,
    ] = 0
    Image.fromarray(masked_pixels).save(os.path.join(args.output_dir, "segmented.jpg"))

    images = pipe(
        image=image,
        mask=mask,
        **inputs,
        timesteps=args.timesteps,
        guidance_scale=args.guidance_scale,
        temperature=args.temperature,
        use_maskgit_generate=not args.not_maskgit_generate,
        num_images_per_prompt=args.num_generations,
        image_size=args.image_size,
    )

    if args.is_class_conditioned:
        # Presumably ``chunk`` splits the flat output list into groups of
        # ``num_generations`` (one group per class id) -- verify against its
        # definition.
        grouped = list(chunk(images, args.num_generations))
        for class_id, class_images in zip(imagenet_class_ids, grouped):
            for i, generated in enumerate(class_images):
                generated.save(os.path.join(args.output_dir, f"output_{class_id}_{i}.jpg"))
    else:
        for i, generated in enumerate(images):
            generated.save(os.path.join(args.output_dir, f"output_{i}.jpg"))