def on_epoch_end()

in torchrecipes/vision/image_generation/callbacks/image_generation.py [0:0]


    def on_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        """Sample images from the generator and log them as a single grid image.

        Draws ``self.num_samples`` latent vectors from a standard normal
        distribution on ``pl_module.device``, runs them through ``pl_module``
        under ``torch.no_grad()`` and ``mode(pl_module, training=False)``,
        reshapes the output to ``(num_samples, *pl_module.img_dim)``, tiles it
        with :func:`torchvision.utils.make_grid`, and logs the grid through
        ``trainer.logger.experiment.add_image`` under the tag
        ``"<ModuleClassName>_images"`` at ``trainer.global_step``.

        Args:
            trainer: The running trainer; its ``logger`` and ``global_step``
                are used for logging.
            pl_module: The generator module. It must expose ``latent_dim``,
                ``img_dim`` (per-sample image shape) and ``device``.

        If no logger is configured, or the logger's experiment is falsy or has
        no ``add_image`` method, a warning is emitted and nothing is generated.
        """
        if not (logger := trainer.logger):
            rank_zero_warn("Trainer must have a logger configured.")
            return
        experiment = logger.experiment
        # Verify image-logging support *before* sampling, so an unsuitable
        # logger produces the intended warning instead of an AttributeError
        # after the generation work has already been done.
        if not experiment or not hasattr(experiment, "add_image"):
            rank_zero_warn("Trainer must have a logger configured that can log images.")
            return

        latent = torch.normal(
            mean=0.0,
            std=1.0,
            size=(self.num_samples, pl_module.latent_dim),
            device=pl_module.device,
        )

        # Generate without autograd tracking; ``mode(..., training=False)`` puts
        # the module in eval mode for this call (presumably restoring the
        # previous mode on exit — TODO confirm against the ``mode`` helper).
        with torch.no_grad(), mode(pl_module, training=False) as eval_module:
            images = eval_module(latent)

        # ``reshape`` (unlike ``view``) also works on non-contiguous outputs.
        images = images.reshape(self.num_samples, *pl_module.img_dim)

        grid = torchvision.utils.make_grid(
            tensor=images,
            nrow=self.nrow,
            padding=self.padding,
            normalize=self.normalize,
            # ``range`` was deprecated in torchvision 0.10 and later removed;
            # ``value_range`` is its drop-in replacement.
            value_range=self.norm_range,
            scale_each=self.scale_each,
            pad_value=self.pad_value,
        )
        tag = f"{pl_module.__class__.__name__}_images"
        experiment.add_image(tag, grid, global_step=trainer.global_step)