def generate_and_log()

in scripts/log_generations_wandb.py [0:0]


def generate_and_log(args):
    """Generate sample images for a file of prompts and log them to a W&B table.

    Starts a W&B run, builds a ``PipelineMuse`` from the text encoder, VAE and
    transformer paths in ``args``, generates ``args.num_generations`` images for
    every prompt in ``args.prompts_file_path`` (one prompt per line), and logs
    the results as a single ``wandb.Table`` under the key ``"samples"``.

    Args:
        args: Parsed command-line namespace. Attributes read here: ``project``,
            ``entity``, ``transformer``, ``checkpoint``, ``run_id``,
            ``timesteps``, ``guidance_scale``, ``temperature``,
            ``text_encoder``, ``vae``, ``device``, ``prompts_file_path``,
            ``batch_size`` and ``num_generations``.
    """
    run_name = f"{args.transformer} samples at checkpoint {args.checkpoint}"
    wandb.init(
        project=args.project,
        entity=args.entity,
        name=run_name,
        notes=(
            f"Samples from {args.run_id} at checkpoint {args.checkpoint} with timesteps={args.timesteps},"
            f" guidance_scale={args.guidance_scale}, temperature={args.temperature}"
        ),
    )

    pipe = PipelineMuse.from_pretrained(
        text_encoder_path=args.text_encoder,
        vae_path=args.vae,
        transformer_path=args.transformer,
    ).to(device=args.device)
    pipe.transformer.enable_xformers_memory_efficient_attention()

    # One prompt per line. Strip the newline that line iteration keeps (it would
    # otherwise be sent to the text encoder and shown in the table) and skip
    # blank lines, which would otherwise become empty prompts.
    with open(args.prompts_file_path, "r", encoding="utf-8") as f:
        prompts = [text for text in (line.strip() for line in f) if text]

    # Split the prompts into batches of size args.batch_size.
    batches = list(chunk(prompts, args.batch_size))

    # Table layout: [prompt, image 0, image 1, ..., image num_generations-1]
    table = wandb.Table(columns=["prompt"] + [f"image {i}" for i in range(args.num_generations)])
    for batch in batches:
        images = pipe(
            batch,
            timesteps=args.timesteps,
            guidance_scale=args.guidance_scale,
            temperature=args.temperature,
            num_images_per_prompt=args.num_generations,
            use_maskgit_generate=True,
            use_fp16=True,
        )

        # Regroup the pipeline output into num_generations images per prompt.
        # NOTE(review): this assumes the pipeline returns a flat sequence ordered
        # prompt-major (all images for prompt 0, then prompt 1, ...) -- confirm
        # against PipelineMuse.
        per_prompt_images = list(chunk(images, args.num_generations))
        for prompt, gen_images in zip(batch, per_prompt_images):
            table.add_data(prompt, *(wandb.Image(image) for image in gen_images))

    wandb.log({"samples": table})