def training_step()

in projects/variational_image_compression/lightning/_prior_autoencoder.py [0:0]


    def training_step(self, *args, **kwargs) -> Tensor:
        """Compute and log the loss for the optimizer selected by ``optimizer_idx``.

        The hook expects exactly three positional arguments, in this order:
        ``batch``, ``batch_idx`` and ``optimizer_idx``. ``**kwargs`` is accepted
        but not used.

        * ``optimizer_idx == 0``: runs the module on ``batch``, evaluates
          ``self.rate_distortion_loss`` and logs ``rate``, ``distortion`` and
          ``rate_distortion``; the ``rate_distortion`` term is returned.
        * any other index: evaluates ``self.bottleneck_loss()``, logs it as
          ``bottleneck_loss`` and returns it.

        Metrics are logged as detached tensors rather than via ``.item()`` so
        that logging does not force a device-to-host synchronisation on every
        step. Each logged value must still be a single-element tensor, which is
        the same requirement ``.item()`` imposed.

        Returns:
            The loss tensor for the selected optimizer, still attached to the
            autograd graph.

        Raises:
            ValueError: If ``args`` does not contain exactly three values.
        """
        batch: Tensor
        batch_idx: int
        optimizer_idx: int

        batch, batch_idx, optimizer_idx = args

        if optimizer_idx == 0:
            x_hat, likelihoods = self(batch)

            rate, distortion, rate_distortion = self.rate_distortion_loss(
                x_hat,
                batch,
                likelihoods,
            )

            # Detach instead of .item(): keeps the values on-device and out of
            # the autograd graph without blocking on a host transfer.
            metrics = {
                "rate": rate.detach(),
                "distortion": distortion.detach(),
                "rate_distortion": rate_distortion.detach(),
            }

            self.log_dict(metrics, sync_dist=True)

            return rate_distortion

        bottleneck_loss = self.bottleneck_loss()

        self.log(
            "bottleneck_loss",
            bottleneck_loss.detach(),
            sync_dist=True,
        )

        return bottleneck_loss