def main()

in training/generate_images.py [0:0]


def main(args):
    """Generate a fixed set of style-conditioned images with an Amused pipeline.

    Builds 29 evaluation prompts from ``args.style_descriptor``, loads the
    pipeline (optionally with a full transformer checkpoint and/or a LoRA
    adapter), then renders the prompts in batches and saves one PNG per
    prompt into ``args.write_images_to``.

    Expected attributes on ``args``:
        style_descriptor, pretrained_model_name_or_path, revision, variant,
        load_transformer_from, load_transformer_lora_from, device,
        batch_size, write_images_to.
    """
    # Fixed evaluation subjects; each becomes "<subject> in <style> style".
    # The strings and their order are the pipeline's eval contract — keep stable.
    subjects = [
        "A chihuahua",
        "A tabby cat",
        "A portrait of chihuahua",
        "An apple on the table",
        "A banana on the table",
        "A church on the street",
        "A church in the mountain",
        "A church in the field",
        "A church on the beach",
        "A chihuahua walking on the street",
        "A tabby cat walking on the street",
        "A portrait of tabby cat",
        "An apple on the dish",
        "A banana on the dish",
        "A human walking on the street",
        "A temple on the street",
        "A temple in the mountain",
        "A temple in the field",
        "A temple on the beach",
        "A chihuahua walking in the forest",
        "A tabby cat walking in the forest",
        "A portrait of human face",
        "An apple on the ground",
        "A banana on the ground",
        "A human walking in the forest",
        "A cabin on the street",
        "A cabin in the mountain",
        "A cabin in the field",
        "A cabin on the beach",
    ]
    prompts = [f"{subject} in {args.style_descriptor} style" for subject in subjects]

    logger.warning(f"generating image for {prompts}")

    logger.warning("loading models")

    pipe_args = {}

    # Optionally swap in a separately trained transformer before pipeline assembly.
    if args.load_transformer_from is not None:
        pipe_args["transformer"] = UVit2DModel.from_pretrained(args.load_transformer_from)

    pipe = AmusedPipeline.from_pretrained(
        pretrained_model_name_or_path=args.pretrained_model_name_or_path,
        revision=args.revision,
        variant=args.variant,
        **pipe_args
    )

    if args.load_transformer_lora_from is not None:
        # BUG FIX: previously loaded from args.load_transformer_from, which is
        # the wrong argument (and None unless a full checkpoint was also given,
        # making os.path.join raise TypeError). Use the LoRA path that gates
        # this branch.
        pipe.transformer = PeftModel.from_pretrained(
            pipe.transformer, os.path.join(args.load_transformer_lora_from), is_trainable=False
        )

    pipe.to(args.device)

    logger.warning("generating images")

    os.makedirs(args.write_images_to, exist_ok=True)

    # Render in batches; each image is saved under its full prompt text.
    for prompt_idx in range(0, len(prompts), args.batch_size):
        images = pipe(prompts[prompt_idx:prompt_idx + args.batch_size]).images

        for image_idx, image in enumerate(images):
            prompt = prompts[prompt_idx + image_idx]
            image.save(os.path.join(args.write_images_to, prompt + ".png"))