in torchrecipes/vision/image_generation/callbacks/image_generation.py [0:0]
def on_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
    """Sample images from the generator and log them as a single grid image.

    At the end of each epoch this draws ``self.num_samples`` latent vectors
    from a standard normal distribution on ``pl_module.device``. It runs
    ``pl_module`` on them with gradients disabled and the module in
    non-training mode, and reshapes the output to
    ``(num_samples, *pl_module.img_dim)``. It then tiles the images with
    ``torchvision.utils.make_grid`` and logs the grid through
    ``trainer.logger.experiment.add_image`` under the tag
    ``"<ModuleClassName>_images"`` at ``trainer.global_step``.

    Args:
        trainer: The running trainer; its ``logger`` and ``global_step`` are read.
        pl_module: The generator module. It must expose ``latent_dim``,
            ``img_dim`` and ``device``, and map a ``(num_samples, latent_dim)``
            latent batch to a tensor holding ``num_samples`` images.

    If no logger is configured, or its experiment has no ``add_image``
    method, a rank-zero warning is emitted and the hook returns without
    generating anything.
    """
    if not (logger := trainer.logger):
        rank_zero_warn("Trainer must have a logger configured.")
        return
    experiment = logger.experiment
    # The grid is logged via ``experiment.add_image`` below. Without this
    # check, an experiment lacking that method would pass the guard and then
    # fail with AttributeError instead of producing the intended warning.
    if not experiment or not callable(getattr(experiment, "add_image", None)):
        rank_zero_warn("Trainer must have a logger configured that can log images.")
        return

    dim = (self.num_samples, pl_module.latent_dim)
    z = torch.normal(mean=0.0, std=1.0, size=dim, device=pl_module.device)

    # Generate images without tracking gradients. ``mode`` is a project helper;
    # presumably it sets the module's training flag to False for the duration
    # of the block and restores it afterwards. Confirm against its definition.
    with torch.no_grad(), mode(pl_module, training=False) as eval_module:
        images = eval_module(z)

    # ``reshape`` matches ``view`` whenever ``view`` would succeed, but also
    # handles a non-contiguous generator output.
    # NOTE(review): img_dim is presumably (C, H, W); confirm with the module.
    images = images.reshape(self.num_samples, *pl_module.img_dim)
    grid = torchvision.utils.make_grid(
        tensor=images,
        nrow=self.nrow,
        padding=self.padding,
        normalize=self.normalize,
        # torchvision 0.10 renamed ``range`` to ``value_range``; the old
        # keyword was deprecated and later removed, so it fails on current
        # releases.
        value_range=self.norm_range,
        scale_each=self.scale_each,
        pad_value=self.pad_value,
    )
    str_title = f"{pl_module.__class__.__name__}_images"
    experiment.add_image(str_title, grid, global_step=trainer.global_step)