in inference/generate_images.py [0:0]
def _to_pixel_range(gen_img, backbone):
    """Convert a batch of raw generator outputs to integer pixel values in [0, 255].

    Args:
        gen_img: generator output tensor (assumed roughly in [-1, 1], judging
            from the scaling constants below -- TODO confirm per backbone).
        backbone: ``test_config["model_backbone"]``; ``"biggan"`` and
            ``"stylegan2"`` are rescaled, anything else is returned unchanged.

    Returns:
        An int tensor clamped to [0, 255] for known backbones, otherwise
        ``gen_img`` as-is.
    """
    if backbone == "biggan":
        # Clamp like the stylegan2 branch so a slight overshoot of the
        # generator's range cannot produce pixel values outside [0, 255].
        return torch.clamp((gen_img * 0.5 + 0.5) * 255, 0, 255).int()
    if backbone == "stylegan2":
        return torch.clamp(gen_img * 127.5 + 128, 0, 255).int()
    return gen_img


def _tile_images(images, n_rows, n_cols):
    """Arrange images of shape (N, H, W, C) into a single (n_rows*H, n_cols*W, C) grid.

    Image ``r * n_cols + c`` is placed in grid row ``r``, column ``c``.

    Raises:
        ValueError: if fewer than ``n_rows * n_cols`` images are available.
    """
    needed = n_rows * n_cols
    if images.shape[0] < needed:
        raise ValueError(
            "Need %i generated images for a %ix%i grid, got %i"
            % (needed, n_rows, n_cols, images.shape[0])
        )
    rows = [
        np.concatenate(list(images[r * n_cols : (r + 1) * n_cols]), axis=1)
        for r in range(n_rows)
    ]
    return np.concatenate(rows, axis=0)


def main(test_config):
    """Generate a grid of images from a pretrained conditional generator and save it.

    Loads the data and the generator, builds noise/conditioning inputs via
    ``get_conditionings`` and runs the generator in mini-batches. The outputs
    are tiled into a grid of ``num_conditionings_gen`` rows by
    ``num_imgs_gen`` columns. When ``visualize_instance_images`` is set, the
    ground-truth conditioning images are prepended as an extra left column.
    The figure is written to the current directory.

    Args:
        test_config: dict-like config. Keys read here: ``model``,
            ``model_backbone``, ``trained_dataset``, ``resolution``,
            ``root_path``, ``which_dataset``, ``visualize_instance_images``,
            ``batch_size``, ``num_conditionings_gen``, ``num_imgs_gen``,
            ``dataset_path`` (only when visualizing instances), ``index``,
            ``swap_target``, ``z_var`` and optionally ``device``
            (default ``"cuda"``).
    """
    # The 256px ImageNet checkpoints are stored under a "_nofeataug" suffix.
    suffix = (
        "_nofeataug"
        if test_config["resolution"] == 256
        and test_config["trained_dataset"] == "imagenet"
        else ""
    )
    exp_name = "%s_%s_%s_res%i%s" % (
        test_config["model"],
        test_config["model_backbone"],
        test_config["trained_dataset"],
        test_config["resolution"],
        suffix,
    )
    # Optional override so the script can also run on other devices.
    device = test_config.get("device", "cuda")

    ### -- Data -- ###
    data, transform_list = get_data(
        test_config["root_path"],
        test_config["model"],
        test_config["resolution"],
        test_config["which_dataset"],
        test_config["visualize_instance_images"],
    )

    ### -- Model -- ###
    generator = get_model(
        exp_name, test_config["root_path"], test_config["model_backbone"], device=device
    )

    ### -- Generate images -- ###
    # Different noise vector per sample, but samples that share a grid row are
    # expected to share a conditioning (layout produced by get_conditionings).
    z, all_feats, all_labels, all_img_paths = get_conditionings(
        test_config, generator, data
    )

    n_samples = z.shape[0]
    batch_size = test_config["batch_size"]
    backbone = test_config["model_backbone"]
    generated_chunks = []
    with torch.no_grad():
        # Stepping by batch_size never yields an empty trailing batch, unlike
        # "1 + n // batch_size" when n is an exact multiple of batch_size.
        for start in range(0, n_samples, batch_size):
            end = min(start + batch_size, n_samples)
            labels_ = (
                all_labels[start:end].to(device) if all_labels is not None else None
            )
            gen_img = generator(
                z[start:end].to(device), labels_, all_feats[start:end].to(device)
            )
            generated_chunks.append(_to_pixel_range(gen_img, backbone).cpu())
    # NCHW -> NHWC for numpy / matplotlib.
    all_generated_images = torch.cat(generated_chunks).permute(0, 2, 3, 1).numpy()

    big_plot = _tile_images(
        all_generated_images,
        test_config["num_conditionings_gen"],
        test_config["num_imgs_gen"],
    )

    # (Optional) Show ground-truth conditioning instances as a left column.
    if test_config["visualize_instance_images"]:
        gt_imgs = [
            np.array(
                transform_list(
                    pil_loader(os.path.join(test_config["dataset_path"], img_path))
                )
            ).astype(np.uint8)
            for img_path in all_img_paths
        ]
        gt_column = np.concatenate(gt_imgs, axis=0)
        # 20-pixel white separator between ground truth and generations.
        white_space = np.full(
            (gt_column.shape[0], 20, gt_column.shape[2]), 255, dtype=np.uint8
        )
        big_plot = np.concatenate([gt_column, white_space, big_plot], axis=1)

    plt.figure(
        figsize=(
            5 * test_config["num_imgs_gen"],
            5 * test_config["num_conditionings_gen"],
        )
    )
    plt.imshow(big_plot)
    plt.axis("off")

    index_tag = (
        "_index" + str(test_config["index"])
        if test_config["index"] is not None
        else ""
    )
    class_tag = (
        "_class_idx" + str(test_config["swap_target"])
        if test_config["swap_target"] is not None
        else ""
    )
    fig_path = "%s_Generations_with_InstanceDataset_%s%s%s_zvar%0.2f.png" % (
        exp_name,
        test_config["which_dataset"],
        index_tag,
        class_tag,
        test_config["z_var"],
    )
    plt.savefig(fig_path, dpi=600, bbox_inches="tight", pad_inches=0)
    print("Done! Figure saved as %s" % (fig_path))