def generate_part()

in generate_creative_birds.py [0:0]


def _select_eye_candidate(candidates, partial_image, min_eye_pixels=3):
    """Return the index of the highest-scoring candidate in ``candidates``.

    Each candidate accumulates two rank-based contributions, each a value
    ``(rank + 1) / num_candidates``; the candidate with the largest total wins.

    Args:
        candidates: generator outputs stacked along dim 0 (presumably
            (K, C, H, W) -- TODO confirm against generate_truncated).
        partial_image: conditioning image; channel 0 is used as the reference
            stroke for the distance term.
        min_eye_pixels: candidates whose summed intensity is not greater than
            this value receive no credit from the distance term.

    Returns:
        int: index into ``candidates`` of the selected candidate.
    """
    num_candidates = candidates.shape[0]
    # Scores stay on the CPU and are indexed with plain Python ints, so the
    # loops below neither sync the GPU per element nor mix devices.
    scores = torch.zeros(num_candidates)

    # Size term: summed intensity per candidate (a pixel count for binary strokes).
    n_pixels = candidates.sum(dim=(-3, -2, -1))
    pixel_counts = n_pixels.tolist()
    size_order = torch.argsort(n_pixels, descending=True).tolist()
    for rank, idx in enumerate(size_order):
        scores[idx] += (rank + 1) / num_candidates

    # Distance term: intensity-weighted distance to the nearest zero of
    # (1 - initial_stroke), i.e. to pixels where the initial stroke equals 1.
    # NOTE(review): the EDT runs over the full (B, 1, H, W) array; that is
    # only equivalent to a per-image 2-D transform when B == 1 -- confirm.
    initial_stroke = partial_image[:, :1].detach().cpu().numpy()
    stroke_dt = distance_transform_edt(1 - initial_stroke)
    stroke_dt = torch.from_numpy(stroke_dt).to(device=candidates.device, dtype=torch.float32)
    dt_pixels = (candidates * stroke_dt).sum(dim=(-3, -2, -1))
    # NOTE(review): the original comment said "the smaller the better", but
    # the smallest distance gets the smallest increment and the final pick
    # takes the HIGHEST score, so this term favours far-away candidates (and
    # the size term favours fewer pixels). Behaviour preserved; confirm intent.
    dist_order = torch.argsort(dt_pixels, descending=False).tolist()
    for rank, idx in enumerate(dist_order):
        # Near-empty candidates get no distance credit, which penalises them.
        if pixel_counts[idx] > min_eye_pixels:
            scores[idx] += (rank + 1) / num_candidates

    return int(torch.argsort(scores, descending=True)[0])


def generate_part(model, partial_image, partial_rgb, color=None, percentage=20, num=0, num_image_tiles=8, trunc_psi=1., save_img=False, results_dir='../results', evolvement=False):
    """Generate a new part conditioned on ``partial_image`` and composite it.

    Args:
        model: object exposing ``eval()``, ``G`` (with ``latent_dim``,
            ``image_size``, ``num_layers``), ``S`` and ``Enc``.
        partial_image: conditioning tensor fed to ``model.Enc``. Its last
            channel is added to the generated part. Its first channel is the
            reference stroke for eye selection.
        partial_rgb: RGB rendering of the existing drawing, composited with
            the generated part.
        color: forwarded to ``gs_to_rgb``.
        percentage: ``'eye'`` selects the best of several sampled batches via
            ``_select_eye_candidate``; any other value is used only in the
            saved file names.
        num: prefix for the saved file names.
        num_image_tiles: number of samples per generator call; also sets the
            grid width when saving.
        trunc_psi: truncation value forwarded to ``generate_truncated``.
        save_img: when true, write the part and composite PNGs to
            ``results_dir``.
        results_dir: output directory for saved images.
        evolvement: unused; kept for call compatibility.

    Returns:
        Tuple ``(generated_partial_images, generated_images,
        generated_partial_rgb, generated_rgb)``, each clamped in place to
        [0, 1]. In the ``'eye'`` branch the batch holds a single image.
    """
    model.eval()
    ext = 'png'
    # torchvision's make_grid iterates range(min(nrow, n)), so nrow must be an int.
    num_rows = max(1, int(np.sqrt(num_image_tiles)))
    latent_dim = model.G.latent_dim
    image_size = model.G.image_size
    num_layers = model.G.num_layers

    # Conditioning inputs are the same for every sample; compute them once.
    image_partial_batch = partial_image[:, -1:, :, :]
    bitmap_feats = model.Enc(partial_image)

    def _sample_batch():
        """Draw fresh latents/noise and run the truncated generator once."""
        latents_z = noise_list(num_image_tiles, num_layers, latent_dim)
        n = image_noise(num_image_tiles, image_size)
        return generate_truncated(model.S, model.G, latents_z, n, trunc_psi=trunc_psi, bitmap_feats=bitmap_feats)

    if percentage == 'eye':
        n_eye = 10
        candidates = torch.cat([_sample_batch() for _ in range(n_eye)], 0)
        best = _select_eye_candidate(candidates, partial_image)
        generated_partial_images = candidates[best].unsqueeze(0)
    else:
        generated_partial_images = _sample_batch()

    generated_partial_rgb = gs_to_rgb(generated_partial_images, color)
    generated_images = generated_partial_images + image_partial_batch
    # Equivalent to a + b - 1: the "ink" (1 - value) of both layers is summed,
    # presumably on a white (1.0) background -- confirm.
    generated_rgb = 1 - ((1 - generated_partial_rgb) + (1 - partial_rgb))
    if save_img:
        torchvision.utils.save_image(generated_partial_rgb, os.path.join(results_dir, f'{str(num)}-{percentage}-comp.{ext}'), nrow=num_rows)
        torchvision.utils.save_image(generated_rgb, os.path.join(results_dir, f'{str(num)}-{percentage}.{ext}'), nrow=num_rows)
    return generated_partial_images.clamp_(0., 1.), generated_images.clamp_(0., 1.), generated_partial_rgb.clamp_(0., 1.), generated_rgb.clamp_(0., 1.)