def log_images()

in projects/deep_video_compression/train.py [0:0]


    def log_images(self, trainer, base_key, image_dict, global_step):
        """Log image mosaics for the known image keys in a single logger call.

        For each of ``"flow"``, ``"image1"``, ``"image2"`` and ``"image2_est"``
        that is present in ``image_dict`` and is neither ``None`` nor empty:

        * the sequence of tensors is concatenated along the last dimension
          (time laid out along the x-axis);
        * the result is split along its leading dimension and re-joined along
          the second-to-last dimension (batch laid out along the y-axis);
        * ``"flow"`` mosaics are converted with
          ``_optical_flow_to_color.optical_flow_to_color``;
        * values are clipped to ``[0, 1]`` and wrapped in ``wandb.Image``.

        All mosaics are sent in ONE ``trainer.logger.experiment.log`` call
        together with ``"global_step"``. Nothing is logged if no key had images.

        Args:
            trainer: object exposing ``trainer.logger.experiment.log``.
            base_key: prefix for logged keys (logged as ``f"{base_key}/{key}"``).
            image_dict: mapping from key to a sequence of tensors; presumably
                each tensor is (batch, channels, height, width) — TODO confirm
                against the caller.
            global_step: value logged under ``"global_step"``.
        """
        keys = ("flow", "image1", "image2", "image2_est")
        payload = {}
        # Visualization never needs gradients; avoid building an autograd
        # graph over the concatenated frames.
        with torch.no_grad():
            for key in keys:
                frames = image_dict.get(key)
                # Skip missing keys and empty sequences (torch.cat raises on an
                # empty sequence).
                if frames is None or len(frames) == 0:
                    continue
                caption = f"{key} (y-axis: batch, x-axis: time)"
                # Time along the x-axis: join the per-timestep tensors on the
                # last dimension.
                mosaic = torch.cat(frames, dim=-1)
                # Batch along the y-axis: split on the leading dimension and
                # join on the second-to-last dimension.
                mosaic = torch.cat(list(mosaic), dim=-2)
                if key == "flow":
                    # The converter is called with a leading dimension added,
                    # which is removed again from its output.
                    mosaic = _optical_flow_to_color.optical_flow_to_color(
                        mosaic.unsqueeze(0)
                    )[0]
                mosaic = torch.clip(mosaic, min=0, max=1.0)
                payload[f"{base_key}/{key}"] = wandb.Image(mosaic, caption=caption)

        # One log call keeps every mosaic on the same wandb step: each separate
        # wandb log call (without commit=False) advances the internal step.
        if payload:
            payload["global_step"] = global_step
            trainer.logger.experiment.log(payload)