in projects/scale_hyperprior_lightning/scale_hyperprior.py [0:0]
def training_step(self, batch, batch_idx, optimizer_idx):
    """Run one training step for the optimizer selected by ``optimizer_idx``.

    Two optimizers are expected:

    * ``0`` -- runs ``self`` on ``batch``, computes the rate-distortion loss
      via ``self.rate_distortion_loss`` and logs its three components.
    * ``1`` -- computes ``self.model.quantile_loss()``, the loss used to learn
      the quantiles of the distribution for the hyperprior.

    Args:
        batch: Input batch; passed to ``self(...)`` and, as the last
            argument, to ``self.rate_distortion_loss``.
        batch_idx: Index of the batch (not used by this method).
        optimizer_idx: Which optimizer this step is for; must be 0 or 1.

    Returns:
        The combined rate-distortion loss for index 0, or the quantile loss
        for index 1.

    Raises:
        ValueError: If ``optimizer_idx`` is neither 0 nor 1.
    """
    if optimizer_idx not in (0, 1):
        raise ValueError(
            f"Received unexpected optimizer index {optimizer_idx}"
            " - should be 0 or 1"
        )

    if optimizer_idx == 0:
        x_hat, y_likelihoods, z_likelihoods = self(batch)
        bpp_loss, distortion_loss, combined_loss = self.rate_distortion_loss(
            x_hat, y_likelihoods, z_likelihoods, batch
        )
        # Log the loss values directly instead of calling ``.item()``:
        # ``.item()`` forces a host/device sync on every step, while
        # Lightning's ``log``/``log_dict`` accept scalar tensors as-is.
        self.log_dict(
            {
                "bpp_loss": bpp_loss,
                "distortion_loss": distortion_loss,
                "loss": combined_loss,
            },
            sync_dist=True,
        )
        return combined_loss

    # optimizer_idx == 1: this is the loss for learning the quantiles of
    # the distribution for the hyperprior.
    quantile_loss = self.model.quantile_loss()
    self.log("quantile_loss", quantile_loss, sync_dist=True)
    return quantile_loss