# main()
# Origin: inference/generate_images.py


def main(test_config):
    """Generate images with a pretrained generator and save them as one grid figure.

    Loads a dataset and a generator selected by ``test_config``, samples noise
    and conditionings, generates images in batches on CUDA, arranges them into a
    ``num_conditionings_gen`` x ``num_imgs_gen`` grid (one conditioning per row),
    optionally prepends ground-truth conditioning instances, and saves the grid
    as a PNG.

    Args:
        test_config (dict): Run configuration. Keys read here: "model",
            "model_backbone", "trained_dataset", "resolution", "root_path",
            "which_dataset", "visualize_instance_images", "batch_size",
            "num_conditionings_gen", "num_imgs_gen", "dataset_path", "index",
            "swap_target", "z_var".

    Side effects:
        Writes a PNG figure to the current working directory and prints its path.
    """
    # The 256-resolution ImageNet model was trained without feature
    # augmentation; its checkpoint name carries a "_nofeataug" suffix.
    suffix = (
        "_nofeataug"
        if test_config["resolution"] == 256
        and test_config["trained_dataset"] == "imagenet"
        else ""
    )
    exp_name = "%s_%s_%s_res%i%s" % (
        test_config["model"],
        test_config["model_backbone"],
        test_config["trained_dataset"],
        test_config["resolution"],
        suffix,
    )
    device = "cuda"

    ### -- Data -- ###
    data, transform_list = get_data(
        test_config["root_path"],
        test_config["model"],
        test_config["resolution"],
        test_config["which_dataset"],
        test_config["visualize_instance_images"],
    )

    ### -- Model -- ###
    generator = get_model(
        exp_name, test_config["root_path"], test_config["model_backbone"], device=device
    )

    ### -- Generate images -- ###
    # Prepare input and conditioning: different noise vector per sample but the
    # same conditioning within each row of the final grid.
    z, all_feats, all_labels, all_img_paths = get_conditionings(
        test_config, generator, data
    )

    ## Generate the images in batches to bound GPU memory use.
    batch_size = test_config["batch_size"]
    num_samples = z.shape[0]
    all_generated_images = []
    with torch.no_grad():
        # Ceiling division. The previous `1 + n // batch_size` yielded an
        # extra empty batch whenever n was an exact multiple of batch_size,
        # calling the generator on zero-length tensors.
        num_batches = (num_samples + batch_size - 1) // batch_size
        for i in range(num_batches):
            start = batch_size * i
            end = min(start + batch_size, num_samples)
            if all_labels is not None:
                labels_ = all_labels[start:end].to(device)
            else:
                labels_ = None
            gen_img = generator(
                z[start:end].to(device), labels_, all_feats[start:end].to(device)
            )
            # Map generator output to integer pixel values in [0, 255];
            # the two backbones use different output ranges.
            if test_config["model_backbone"] == "biggan":
                gen_img = ((gen_img * 0.5 + 0.5) * 255).int()
            elif test_config["model_backbone"] == "stylegan2":
                gen_img = torch.clamp((gen_img * 127.5 + 128), 0, 255).int()
            all_generated_images.append(gen_img.cpu())
    all_generated_images = torch.cat(all_generated_images)
    # NCHW -> NHWC for matplotlib.
    all_generated_images = all_generated_images.permute(0, 2, 3, 1).numpy()

    # Tile images: one row per conditioning, num_imgs_gen columns per row.
    big_plot = []
    for i in range(0, test_config["num_conditionings_gen"]):
        row = []
        for j in range(0, test_config["num_imgs_gen"]):
            subplot_idx = (i * test_config["num_imgs_gen"]) + j
            row.append(all_generated_images[subplot_idx])
        row = np.concatenate(row, axis=1)
        big_plot.append(row)
    big_plot = np.concatenate(big_plot, axis=0)

    # (Optional) Show ImageNet ground-truth conditioning instances to the left
    # of the generated grid, separated by a 20-pixel white gutter.
    if test_config["visualize_instance_images"]:
        all_gt_imgs = []
        for i in range(0, len(all_img_paths)):
            all_gt_imgs.append(
                np.array(
                    transform_list(
                        pil_loader(
                            os.path.join(test_config["dataset_path"], all_img_paths[i])
                        )
                    )
                ).astype(np.uint8)
            )
        all_gt_imgs = np.concatenate(all_gt_imgs, axis=0)
        white_space = (
            np.ones((all_gt_imgs.shape[0], 20, all_gt_imgs.shape[2])) * 255
        ).astype(np.uint8)
        big_plot = np.concatenate([all_gt_imgs, white_space, big_plot], axis=1)

    plt.figure(
        figsize=(
            5 * test_config["num_imgs_gen"],
            5 * test_config["num_conditionings_gen"],
        )
    )
    plt.imshow(big_plot)
    plt.axis("off")

    fig_path = "%s_Generations_with_InstanceDataset_%s%s%s_zvar%0.2f.png" % (
        exp_name,
        test_config["which_dataset"],
        "_index" + str(test_config["index"])
        if test_config["index"] is not None
        else "",
        "_class_idx" + str(test_config["swap_target"])
        if test_config["swap_target"] is not None
        else "",
        test_config["z_var"],
    )
    plt.savefig(fig_path, dpi=600, bbox_inches="tight", pad_inches=0)

    print("Done! Figure saved as %s" % (fig_path))