in generate_creative_birds.py [0:0]
def _select_best_eye(candidates, partial_image):
    """Pick one eye candidate by combining a size rank and a stroke-distance rank.

    Args:
        candidates: Stacked generated eye images; dim 0 indexes candidates and
            the last three dims are summed per candidate.
        partial_image: Sketch tensor whose channel 0 is treated as the initial
            stroke for the distance transform.

    Returns:
        The highest-scoring candidate with a leading batch dimension of 1.
    """
    n_candidates = candidates.size(0)
    device = candidates.device
    # Position r (0-based) in a sort order contributes (r + 1) / n_candidates.
    rank_values = torch.arange(1, n_candidates + 1, device=device, dtype=torch.float32) / n_candidates

    # Eye size rank: total ink per candidate, sorted largest-first, so the
    # smallest eye receives the largest contribution.
    n_pixels = candidates.sum(dim=(-3, -2, -1))
    size_score = torch.zeros(n_candidates, device=device)
    size_score[torch.argsort(n_pixels, descending=True)] = rank_values

    # Eye distance rank: ink weighted by the distance transform of the
    # initial stroke's background.
    initial_stroke = partial_image[:, :1].detach().cpu().numpy()
    # NOTE(review): the EDT runs over the whole array, so for a batch > 1 the
    # distances can cross batch items -- assumes partial_image has batch 1; confirm.
    initial_stroke_dt = torch.as_tensor(
        distance_transform_edt(1 - initial_stroke), dtype=torch.float32, device=device)
    dt_pixels = (candidates * initial_stroke_dt).sum(dim=(-3, -2, -1))
    dist_score = torch.zeros(n_candidates, device=device)
    # NOTE(review): the original comment said "the smaller the better", but the
    # ascending sort gives the smallest distance the smallest contribution and
    # the highest total score wins below. Behaviour preserved; confirm intent.
    dist_score[torch.argsort(dt_pixels, descending=False)] = rank_values
    # Only candidates with more than 3 units of ink receive a distance score.
    dist_score = dist_score * (n_pixels > 3)

    scores = size_score + dist_score
    return candidates[torch.argmax(scores).item()].unsqueeze(0)


def generate_part(model, partial_image, partial_rgb, color=None, percentage=20, num=0, num_image_tiles=8, trunc_psi=1., save_img=False, results_dir='../results', evolvement=False):
    """Generate a new part on top of a partial sketch and composite it.

    When ``percentage == 'eye'``, ten sampling rounds are run and a single
    candidate is chosen by ``_select_best_eye``. Otherwise one sampling round
    is used.

    Args:
        model: Object exposing ``eval()``, an encoder ``Enc``, ``S`` and a
            generator ``G`` with ``latent_dim``, ``image_size`` and
            ``num_layers`` attributes.
        partial_image: Sketch so far. Channel 0 is used as the initial stroke
            for eye ranking; the last channel is added to the generated part.
            (Presumably (B, C, H, W) -- confirm against callers.)
        partial_rgb: RGB rendering of the partial sketch, composited with the
            generated part.
        color: Forwarded to ``gs_to_rgb`` to colour the generated part.
        percentage: Tag used in saved file names; the string ``'eye'`` enables
            candidate ranking.
        num: Index used in saved file names.
        num_image_tiles: Count passed to ``noise_list``/``image_noise`` per
            sampling round; its floored square root (at least 1) is used as
            the grid row width when saving.
        trunc_psi: Truncation value forwarded to ``generate_truncated``.
        save_img: If True, write the part and the composite to ``results_dir``.
        results_dir: Directory for saved images.
        evolvement: Unused; kept for interface compatibility.

    Returns:
        Tuple ``(part, part + last partial channel, part RGB, composite RGB)``,
        each clamped in place to [0, 1].
    """
    model.eval()
    ext = 'png'
    # make_grid needs an integer row width; a float from np.sqrt raises TypeError.
    num_rows = max(1, int(np.sqrt(num_image_tiles)))
    latent_dim = model.G.latent_dim
    image_size = model.G.image_size
    num_layers = model.G.num_layers

    # Neither depends on the sampled noise, so compute them once.
    image_partial_batch = partial_image[:, -1:, :, :]
    bitmap_feats = model.Enc(partial_image)

    def _sample():
        # One sampling round conditioned on the encoded partial sketch.
        latents_z = noise_list(num_image_tiles, num_layers, latent_dim)
        n = image_noise(num_image_tiles, image_size)
        return generate_truncated(model.S, model.G, latents_z, n, trunc_psi=trunc_psi, bitmap_feats=bitmap_feats)

    if percentage == 'eye':
        n_eye = 10
        candidates = torch.cat([_sample() for _ in range(n_eye)], 0)
        generated_partial_images = _select_best_eye(candidates, partial_image)
    else:
        generated_partial_images = _sample()

    generated_partial_rgb = gs_to_rgb(generated_partial_images, color)
    generated_images = generated_partial_images + image_partial_batch
    # Invert both RGB images, add them, then invert back to overlay the part.
    generated_rgb = 1 - ((1 - generated_partial_rgb) + (1 - partial_rgb))
    if save_img:
        torchvision.utils.save_image(generated_partial_rgb, os.path.join(results_dir, f'{str(num)}-{percentage}-comp.{ext}'), nrow=num_rows)
        torchvision.utils.save_image(generated_rgb, os.path.join(results_dir, f'{str(num)}-{percentage}.{ext}'), nrow=num_rows)
    return (generated_partial_images.clamp_(0., 1.),
            generated_images.clamp_(0., 1.),
            generated_partial_rgb.clamp_(0., 1.),
            generated_rgb.clamp_(0., 1.))