in generate_creative_creatures.py [0:0]
def _translate_planes(image, dx, dy):
    """Return a copy of ``image`` with every 2-D plane shifted by (dx, dy) pixels.

    ``image`` is indexed as ``image[b, c]`` -> (H, W) plane, i.e. a 4-D
    (batch, channel, H, W) tensor. Each plane is warped independently with
    ``cv2.warpAffine`` using a pure-translation matrix. Pixels shifted in from
    outside the frame take cv2's default constant border value (0). The
    result has the same dtype and device as ``image`` and carries no autograd
    history.
    """
    shifted = torch.zeros_like(image)
    height, width = image.shape[-2:]
    matrix = np.float32([[1, 0, dx], [0, 1, dy]])
    # cv2 operates on host NumPy arrays, so move the whole tensor to the CPU once.
    planes = image.detach().cpu().numpy()
    for b in range(image.shape[0]):
        for c in range(image.shape[1]):
            # cv2 dsize is (width, height).
            warped = cv2.warpAffine(np.ascontiguousarray(planes[b, c]), matrix, (width, height))
            # .to(shifted) matches dtype/device instead of assuming a CUDA tensor.
            shifted[b, c] = torch.from_numpy(warped).to(shifted)
    return shifted


def generate_part(model, partial_image, partial_rgb, color=None, part_name=20, num=0, num_image_tiles=8, trunc_psi=1., save_img=False, trans_std=2, results_dir='../results/bird_seq_unet_5fold'):
    """Generate ``num_image_tiles ** 2`` candidate parts conditioned on a partial sketch.

    The partial sketch is randomly translated (Gaussian jitter with std
    ``trans_std`` pixels) and encoded with ``model.Enc``. Samples are then drawn
    from ``generate_truncated(model.S, model.G, ...)``. The same translation is
    undone on every generated sample, so the parts line up with the original
    ``partial_image`` before compositing.

    Args:
        model: object exposing ``G`` (with ``latent_dim``, ``image_size``,
            ``num_layers``), ``Enc`` and ``S``; put in eval mode here.
        partial_image: 4-D image tensor of the partial sketch. Its last channel
            is added to the generated parts to form the composite.
        partial_rgb: RGB rendering of the partial sketch, composited with the
            RGB rendering of the generated parts.
        color: forwarded to ``gs_to_rgb`` when colouring the generated parts.
        part_name: used only in output filenames.
        num: used only in output filenames.
        num_image_tiles: samples are generated as a ``num_image_tiles`` x
            ``num_image_tiles`` grid (``num_image_tiles ** 2`` samples).
        trunc_psi: truncation value forwarded to ``generate_truncated``.
        save_img: if true, write ``{num}-{part_name}-comp.png`` (parts only) and
            ``{num}-{part_name}.png`` (composite) grids into ``results_dir``.
        trans_std: std (pixels) of the random x/y translation jitter.
        results_dir: directory for saved images.

    Returns:
        Tuple ``(generated_partial_images, generated_images,
        generated_partial_rgb, generated_rgb)``, each clamped in place to [0, 1].
    """
    model.eval()
    ext = 'png'
    num_rows = num_image_tiles
    latent_dim = model.G.latent_dim
    image_size = model.G.image_size
    num_layers = model.G.num_layers

    # latents and noise for the num_rows x num_rows sample grid
    latents_z = noise_list(num_rows ** 2, num_layers, latent_dim)
    n = image_noise(num_rows ** 2, image_size)
    image_partial_batch = partial_image[:, -1:, :, :]

    # Random translation jitter applied to the conditioning sketch.
    dx = np.random.normal(0, trans_std)
    dy = np.random.normal(0, trans_std)
    # A rotation angle (std 3) and a scale (std 2) were historically drawn here
    # but never applied. The draws are kept so the NumPy RNG stream is consumed
    # exactly as before.
    np.random.normal(0, 3)
    np.random.normal(0, 2)
    translated_image = _translate_planes(partial_image, dx, dy)
    bitmap_feats = model.Enc(translated_image)

    generated = generate_truncated(model.S, model.G, latents_z, n, trunc_psi=trunc_psi, bitmap_feats=bitmap_feats)
    # Undo the jitter on *every* generated sample, not just the first one,
    # so each part aligns with the un-translated partial sketch.
    generated_partial_images = _translate_planes(generated, -dx, -dy)

    # post process
    generated_partial_rgb = gs_to_rgb(generated_partial_images, color)
    generated_images = generated_partial_images + image_partial_batch
    # Combine "ink" (1 - value) of both layers. This presumes strokes are dark on
    # a white (1.0) background; confirm against gs_to_rgb.
    generated_rgb = 1 - ((1 - generated_partial_rgb) + (1 - partial_rgb))
    if save_img:
        torchvision.utils.save_image(generated_partial_rgb, os.path.join(results_dir, f'{str(num)}-{part_name}-comp.{ext}'), nrow=num_rows)
        torchvision.utils.save_image(generated_rgb, os.path.join(results_dir, f'{str(num)}-{part_name}.{ext}'), nrow=num_rows)
    return generated_partial_images.clamp_(0., 1.), generated_images.clamp_(0., 1.), generated_partial_rgb.clamp_(0., 1.), generated_rgb.clamp_(0., 1.)