def compute_batch_loss()

in projects/deep_video_compression/_utils.py [0:0]


    def compute_batch_loss(self, image1, image2):
        """Run the model on a frame pair and gather its loss terms.

        ``self.forward(image1, image2)`` is invoked exactly once. Its result
        provides the estimate of the second frame (``image2_est``), the flow
        (``flow``), and the flow / residual probabilities that feed the two
        entropy terms.

        Args:
            image1: First frame batch; must satisfy ``ndim == 4``.
            image2: Second frame batch, used as the distortion target; must
                satisfy ``ndim == 4``. Presumably laid out as
                (batch, channels, height, width), since the last two dims
                are used as the spatial extent — TODO confirm with callers.

        Returns:
            A 2-tuple ``(LossValues, OutputTensors)``:

            * ``LossValues`` contains ``distortion_loss``, computed by
              ``distortion_fn(image2, image2_est)``. It also contains
              ``flow_entropy_loss`` and ``resid_entropy_loss``, each
              computed by ``entropy_fn`` with a pixel count of
              ``image1.shape[0] * image1.shape[-2] * image1.shape[-1]``.
            * ``OutputTensors`` contains ``.detach()``-ed versions of the
              flow, both input frames and the reconstructed second frame.

        Raises:
            AssertionError: If no entropy function is configured, or if
                either input is not 4-dimensional. These checks are skipped
                when Python runs with ``-O``.
        """
        losses = self.loss_functions
        assert losses.entropy_fn is not None
        assert image1.ndim == image2.ndim == 4

        result = self.forward(image1, image2)

        # Reconstruction error of the predicted second frame.
        distortion = losses.distortion_fn(image2, result.image2_est)

        # Both rate terms share one normaliser: the leading (batch) size times
        # the two trailing dimensions of the first frame.
        batch_size = image1.shape[0]
        height, width = image1.shape[-2], image1.shape[-1]
        pixel_count = batch_size * height * width

        # Flow term first, then the residual term, matching the tuple order.
        flow_rate, resid_rate = (
            losses.entropy_fn(probs, pixel_count)
            for probs in (result.flow_probabilities, result.resid_probabilities)
        )

        loss_values = LossValues(
            distortion_loss=distortion,
            flow_entropy_loss=flow_rate,
            resid_entropy_loss=resid_rate,
        )
        # Detached tensors are cut off from the autograd graph.
        detached = OutputTensors(
            flow=result.flow.detach(),
            image1=image1.detach(),
            image2=image2.detach(),
            image2_est=result.image2_est.detach(),
        )
        return loss_values, detached