training_step()

From projects/scale_hyperprior_lightning/scale_hyperprior.py


    def training_step(self, batch, batch_idx, optimizer_idx):
        """Run one training step for the two-optimizer compression setup.

        Optimizer 0 trains the main rate-distortion objective; optimizer 1
        trains the quantile loss used to fit the hyperprior's learned
        distribution. Any other index is a programming error.

        Args:
            batch: Input batch of images (passed straight to ``self(...)``).
            batch_idx: Index of the batch within the epoch (unused here).
            optimizer_idx: Which optimizer Lightning is currently stepping.

        Returns:
            The scalar loss tensor for the active optimizer.

        Raises:
            ValueError: If ``optimizer_idx`` is neither 0 nor 1.
        """
        if optimizer_idx == 1:
            # Auxiliary objective: learn the quantiles of the hyperprior's
            # distribution.
            quantile_loss = self.model.quantile_loss()
            self.log("quantile_loss", quantile_loss.item(), sync_dist=True)
            return quantile_loss

        if optimizer_idx == 0:
            # Main objective: forward pass, then the rate-distortion trade-off.
            reconstruction, y_likelihoods, z_likelihoods = self(batch)
            bpp, distortion, total = self.rate_distortion_loss(
                reconstruction, y_likelihoods, z_likelihoods, batch
            )
            # .item() detaches the scalars so logging never holds the graph.
            self.log_dict(
                {
                    "bpp_loss": bpp.item(),
                    "distortion_loss": distortion.item(),
                    "loss": total.item(),
                },
                sync_dist=True,
            )
            return total

        raise ValueError(
            f"Received unexpected optimizer index {optimizer_idx}"
            " - should be 0 or 1"
        )