def evaluate()

in part_generator.py [0:0]


    def evaluate(self, num = 0, num_image_tiles = 8, trunc = 1.0, rgb = False):
        """Sample partial-sketch completions and save them as image grids.

        Draws ``num_image_tiles ** 2`` examples via
        ``self.dataset_G.sample_partial_test``, encodes the condition batch with
        ``self.GAN.Enc``, generates parts with ``self.generate_truncated`` and
        writes grids with ``num_image_tiles`` images per row to
        ``self.results_dir / self.name``:

        * ``{num}-part.png`` -- the partial input sketch
        * ``{num}-real.png`` -- ``part_only_batch`` from the dataset
        * ``{num}-full.png`` -- ``image_batch`` (grayscale mode), or the
          colorized real part overlaid on the colorized partial sketch (rgb mode)
        * ``{num}-comp.png`` -- the generated part alone
        * ``{num}.png``      -- the generated part combined with the partial sketch

        Leaves ``self.GAN`` in eval mode.

        Args:
            num: Filename prefix for the saved grids.
            num_image_tiles: Images per grid row; the grid holds
                ``num_image_tiles ** 2`` samples.
            trunc: Currently unused -- truncation always uses ``self.trunc_psi``.
                Kept only for interface compatibility.
            rgb: If True, colorize outputs with ``gs_to_rgb`` before saving.
        """
        self.GAN.eval()
        ext = 'png'
        num_rows = num_image_tiles
        num_samples = num_rows ** 2

        latent_dim = self.GAN.G.latent_dim
        image_size = self.GAN.G.image_size
        num_layers = self.GAN.G.num_layers

        def save(tensor, suffix):
            # Every grid goes to <results_dir>/<name>/<num><suffix>.<ext>.
            path = self.results_dir / self.name / f'{str(num)}{suffix}.{ext}'
            torchvision.utils.save_image(tensor, str(path), nrow=num_rows)

        def overlay(a, b):
            # Sum the inverted images (1 - x), clamp the *whole* sum to [0, 1],
            # then invert back. Clamping must wrap the sum (not just one term)
            # so the composite stays inside [0, 1].
            return 1 - ((1 - a) + (1 - b)).clamp(0., 1.)

        # Inference only: no autograd graph is needed for sampling, so disable
        # gradient tracking to avoid holding activations in memory.
        with torch.no_grad():
            latents_z = noise_list(num_samples, num_layers, latent_dim)
            n = image_noise(num_samples, image_size)

            image_batch, image_cond_batch, part_only_batch = [
                item.cuda() for item in self.dataset_G.sample_partial_test(num_samples)
            ]
            # The LAST channel of the condition batch is used as the whole
            # partial sketch (presumably the dataset stores it there -- confirm).
            image_partial_batch = image_cond_batch[:, -1:, :, :]

            bitmap_feats = self.GAN.Enc(image_cond_batch)

            # NOTE: truncation uses self.trunc_psi; the `trunc` argument is ignored.
            generated_partial_images = self.generate_truncated(
                self.GAN.S, self.GAN.G, latents_z, n,
                trunc_psi = self.trunc_psi, bitmap_feats=bitmap_feats,
            )

        if not rgb:
            # Combine generated part with the partial sketch via pixelwise max.
            generated_images = torch.max(generated_partial_images, image_partial_batch)
            save(image_partial_batch, '-part')
            save(part_only_batch, '-real')
            save(image_batch, '-full')
            save(generated_partial_images, '-comp')
            save(generated_images.clamp(0., 1.), '')
        else:
            partial_rgb = gs_to_rgb(image_partial_batch, self.default_color)
            part_rgb = gs_to_rgb(part_only_batch, self.color)
            save(partial_rgb, '-part')
            save(part_rgb, '-real')
            save(overlay(part_rgb, partial_rgb), '-full')
            generated_part_rgb = gs_to_rgb(generated_partial_images, self.color)
            save(generated_part_rgb, '-comp')
            save(overlay(generated_part_rgb, partial_rgb), '')