in projects/variational_image_compression/lightning/_prior_autoencoder.py [0:0]
def training_step(self, *args, **kwargs) -> Tensor:
    """Compute and log the training loss for the optimizer selected by ``optimizer_idx``.

    Exactly three positional arguments are expected, in the order
    ``(batch, batch_idx, optimizer_idx)``. Any other count raises
    ``ValueError`` from the tuple unpacking. ``kwargs`` is ignored.

    * ``optimizer_idx == 0``: runs the model on ``batch``. It then computes
      ``rate``, ``distortion`` and ``rate_distortion`` via
      ``self.rate_distortion_loss``, logs all three and returns
      ``rate_distortion``.
    * any other ``optimizer_idx``: computes, logs and returns
      ``self.bottleneck_loss()``. Presumably this is the auxiliary
      entropy-bottleneck loss; TODO confirm against its definition.

    All metrics are logged with ``sync_dist=True``.

    Returns:
        The loss tensor for the selected optimizer. It is returned
        un-detached so the caller can back-propagate through it.
    """
    batch: Tensor
    optimizer_idx: int
    # batch_idx is part of the positional layout but is not used here.
    batch, _batch_idx, optimizer_idx = args

    if optimizer_idx != 0:
        bottleneck_loss = self.bottleneck_loss()
        # Log a detached tensor rather than ``.item()``: this avoids a
        # per-step device->host sync, and the value can be reduced across
        # processes without first being rebuilt as a tensor.
        self.log("bottleneck_loss", bottleneck_loss.detach(), sync_dist=True)
        return bottleneck_loss

    x_hat, likelihoods = self(batch)
    rate, distortion, rate_distortion = self.rate_distortion_loss(
        x_hat,
        batch,
        likelihoods,
    )
    # Detached tensors keep the logged values identical while dropping the
    # three blocking ``.item()`` calls the step previously made.
    self.log_dict(
        {
            "rate": rate.detach(),
            "distortion": distortion.detach(),
            "rate_distortion": rate_distortion.detach(),
        },
        sync_dist=True,
    )
    return rate_distortion