in part_generator.py [0:0]
def evaluate(self, num = 0, num_image_tiles = 8, trunc = 1.0, rgb = False):
    """Generate parts for a batch of test sketches and save them as image grids.

    Samples ``num_image_tiles ** 2`` test items from ``self.dataset_G``, encodes
    the conditioning images with ``self.GAN.Enc``, generates images with
    ``self.generate_truncated`` and writes five PNG grids into
    ``self.results_dir / self.name``:

    * ``<num>-part.png`` -- the input partial sketch
    * ``<num>-real.png`` -- the sampled ``part_only_batch``
    * ``<num>-full.png`` -- the sampled full image (greyscale mode) or the
      composite of the two images above (rgb mode)
    * ``<num>-comp.png`` -- the generator output on its own
    * ``<num>.png``      -- the generator output merged with the partial sketch

    Args:
        num: Prefix used in the output file names.
        num_image_tiles: Side length of the saved grids; the batch size is its
            square.
        trunc: Currently unused -- ``self.trunc_psi`` is what gets passed to
            ``generate_truncated``. Kept so existing callers keep working.
        rgb: If true, convert every image with ``gs_to_rgb`` (``self.color``
            for parts, ``self.default_color`` for the partial sketch) before
            saving; otherwise save the tensors as they are.
    """
    self.GAN.eval()
    ext = 'png'
    num_rows = num_image_tiles
    num_samples = num_rows ** 2
    out_dir = self.results_dir / self.name

    latent_dim = self.GAN.G.latent_dim
    image_size = self.GAN.G.image_size
    num_layers = self.GAN.G.num_layers

    def save_grid(images, suffix = ''):
        # Every output is a grid with `num_rows` images per row, named '<num><suffix>.<ext>'.
        torchvision.utils.save_image(images, str(out_dir / f'{str(num)}{suffix}.{ext}'), nrow=num_rows)

    # latents and noise (drawn before the dataset sample, as before)
    latents_z = noise_list(num_samples, num_layers, latent_dim)
    n = image_noise(num_samples, image_size)

    image_batch, image_cond_batch, part_only_batch = [item.cuda() for item in self.dataset_G.sample_partial_test(num_samples)]
    # Use the last channel of the conditioning batch as the input partial sketch.
    image_partial_batch = image_cond_batch[:, -1:, :, :]

    # Evaluation only: run the encoder and generator without building an autograd graph.
    with torch.no_grad():
        bitmap_feats = self.GAN.Enc(image_cond_batch)
        generated_partial_images = self.generate_truncated(self.GAN.S, self.GAN.G, latents_z, n, trunc_psi = self.trunc_psi, bitmap_feats=bitmap_feats)

    if not rgb:
        # Merge the generated output with the partial sketch by element-wise maximum.
        generated_images = torch.max(generated_partial_images, image_partial_batch)

        save_grid(image_partial_batch, '-part')
        save_grid(part_only_batch, '-real')
        save_grid(image_batch, '-full')
        save_grid(generated_partial_images, '-comp')
        save_grid(generated_images.clamp(0., 1.))
    else:
        partial_rgb = gs_to_rgb(image_partial_batch, self.default_color)
        part_rgb = gs_to_rgb(part_only_batch, self.color)
        generated_part_rgb = gs_to_rgb(generated_partial_images, self.color)

        # Composite in inverted space: invert both layers, add them, clamp the
        # SUM to [0, 1], then invert back. (The clamp used to apply to the
        # second term only, which let the composite go negative.)
        # NOTE(review): presumably gs_to_rgb returns values in [0, 1] -- confirm.
        partial_inverted = 1 - partial_rgb

        save_grid(partial_rgb, '-part')
        save_grid(part_rgb, '-real')
        save_grid(1 - ((1 - part_rgb) + partial_inverted).clamp(0., 1.), '-full')
        save_grid(generated_part_rgb, '-comp')
        save_grid(1 - ((1 - generated_part_rgb) + partial_inverted).clamp(0., 1.))