in projects/deep_video_compression/train.py [0:0]
def log_images(self, trainer, base_key, image_dict, global_step):
    """Log one image mosaic per available key to the experiment logger.

    For each of ``"flow"``, ``"image1"``, ``"image2"`` and ``"image2_est"``
    present in ``image_dict``, the per-timestep tensors are tiled into a single
    image with time along the x-axis and batch along the y-axis. Flow is
    converted to color. The result is clipped to ``[0, 1]`` and logged as a
    ``wandb.Image`` under ``f"{base_key}/{key}"`` together with
    ``"global_step"`` (one ``log`` call per key).

    Args:
        trainer: Object exposing ``trainer.logger.experiment.log`` (presumably
            a Lightning trainer with a wandb run -- TODO confirm).
        base_key: Prefix for the logged key, e.g. ``"train"``.
        image_dict: Mapping from key to a sequence of tensors, one per time
            step. Assumes each tensor is ``(batch, channels, height, width)``
            -- TODO confirm against caller. Keys that are missing, ``None`` or
            map to an empty sequence are skipped.
        global_step: Value logged under ``"global_step"`` with each image.
    """
    keys = ("flow", "image1", "image2", "image2_est")
    for key in keys:
        frames = image_dict.get(key)
        # An empty sequence would make torch.cat raise, and there is nothing
        # to log anyway.
        if frames is None or len(frames) == 0:
            continue

        caption = f"{key} (y-axis: batch, x-axis: time)"
        # Visualization only: keep these ops out of the autograd graph so
        # logging model outputs does not retain extra graph memory.
        with torch.no_grad():
            # Concatenate time steps along the last dim (x-axis / width) ...
            mosaic = torch.cat(list(frames), dim=-1)
            # ... then split along dim 0 (batch) and stack along dim -2
            # (y-axis / height).
            mosaic = torch.cat(list(mosaic), dim=-2)
            if key == "flow":
                # The converter is called on a batch of one; index it back off.
                mosaic = _optical_flow_to_color.optical_flow_to_color(
                    mosaic.unsqueeze(0)
                )[0]
            mosaic = torch.clip(mosaic, min=0, max=1.0)

        trainer.logger.experiment.log(
            {
                f"{base_key}/{key}": wandb.Image(mosaic, caption=caption),
                "global_step": global_step,
            }
        )