# generate_part() — excerpt from generate_creative_creatures.py

def generate_part(model, partial_image, partial_rgb, color=None, part_name=20, num=0, num_image_tiles=8, trunc_psi=1., save_img=False, trans_std=2, results_dir='../results/bird_seq_unet_5fold'):
    """Generate a new sketch part conditioned on a partially drawn sketch.

    The partial sketch is jittered with a random pixel translation before
    being encoded; the generated output is then shifted back by the inverse
    translation so it aligns with the original canvas.

    Args:
        model: trained model exposing ``G`` (generator), ``S`` (style
            network) and ``Enc`` (sketch encoder).
        partial_image: grayscale sketch tensor, assumed (1, C, H, W) —
            TODO confirm against callers.
        partial_rgb: RGB rendering of the partial sketch.
        color: optional color forwarded to ``gs_to_rgb``.
        part_name: label embedded in saved file names.
        num: index embedded in saved file names.
        num_image_tiles: tiles per grid row; the batch size is its square.
        trunc_psi: truncation psi used by ``generate_truncated``.
        save_img: when True, write the result grids to ``results_dir``.
        trans_std: std-dev (pixels) of the random translation jitter.
        results_dir: output directory for saved image grids.

    Returns:
        Tuple of (generated part, combined sketch, part RGB, combined RGB),
        each clamped in-place to [0, 1].
    """
    model.eval()
    ext = 'png'
    num_rows = num_image_tiles
    latent_dim = model.G.latent_dim
    image_size = model.G.image_size
    num_layers = model.G.num_layers

    def translate_image(image, trans_std=2, rot_std=3, scale_std=2):
        """Apply a random pixel translation to every channel of ``image``.

        NOTE(review): ``theta`` and ``scale`` are sampled and returned, but
        only the translation ``T`` is ever applied.  The rotation matrix
        built from ``theta`` was dead code in the original and has been
        removed — confirm whether rotation/scaling was intended.
        """
        affine_image = torch.zeros_like(image)
        side = image.shape[-1]
        x_shift = np.random.normal(0, trans_std)
        y_shift = np.random.normal(0, trans_std)
        theta = np.random.normal(0, rot_std)
        scale = int(np.random.normal(0, scale_std))
        T = np.float32([[1, 0, x_shift], [0, 1, y_shift]])
        for i in range(image.shape[1]):
            sketch_channel = image[0, i].cpu().data.numpy()
            sketch_translation = cv2.warpAffine(sketch_channel, T, (side, side))
            # Keep the result on the input's device instead of hard-coding
            # torch.cuda.FloatTensor (original crashed on CPU-only runs).
            affine_image[0, i] = torch.as_tensor(sketch_translation, device=image.device)
        return affine_image, x_shift, y_shift, theta, scale

    def recover_image(image, x_shift, y_shift, theta, scale):
        """Undo the translation applied by ``translate_image``.

        Only the translation is inverted; ``theta``/``scale`` are accepted
        for signature symmetry but never applied (see translate_image).
        """
        x_shift *= -1
        y_shift *= -1
        theta *= -1
        affine_image = torch.zeros_like(image)
        side = image.shape[-1]
        T = np.float32([[1, 0, x_shift], [0, 1, y_shift]])
        for i in range(image.shape[1]):
            sketch_channel = image[0, i].cpu().data.numpy()
            sketch_translation = cv2.warpAffine(sketch_channel, T, (side, side))
            affine_image[0, i] = torch.as_tensor(sketch_translation, device=image.device)
        return affine_image

    # Latents and per-layer noise for the generator batch.
    latents_z = noise_list(num_rows ** 2, num_layers, latent_dim)
    n = image_noise(num_rows ** 2, image_size)
    # Last channel holds the accumulated partial sketch — TODO confirm.
    image_partial_batch = partial_image[:, -1:, :, :]
    # Jitter the conditioning sketch, encode it, generate, then shift the
    # generated part back so it lines up with the un-jittered canvas.
    translated_image, dx, dy, theta, scale = translate_image(partial_image, trans_std=trans_std)
    bitmap_feats = model.Enc(translated_image)
    generated_partial_images = recover_image(generate_truncated(model.S, model.G, latents_z, n, trunc_psi=trunc_psi, bitmap_feats=bitmap_feats), dx, dy, theta, scale)
    # Post-process: colorize the new part and composite it over the input
    # (multiplying "ink" by summing in 1-x space for the RGB version).
    generated_partial_rgb = gs_to_rgb(generated_partial_images, color)
    generated_images = generated_partial_images + image_partial_batch
    generated_rgb = 1 - ((1 - generated_partial_rgb) + (1 - partial_rgb))
    if save_img:
        torchvision.utils.save_image(generated_partial_rgb, os.path.join(results_dir, f'{str(num)}-{part_name}-comp.{ext}'), nrow=num_rows)
        torchvision.utils.save_image(generated_rgb, os.path.join(results_dir, f'{str(num)}-{part_name}.{ext}'), nrow=num_rows)
    return generated_partial_images.clamp_(0., 1.), generated_images.clamp_(0., 1.), generated_partial_rgb.clamp_(0., 1.), generated_rgb.clamp_(0., 1.)