in sat/vae_modules/autoencoder.py [0:0]
def inner_training_step(self, batch: dict, batch_idx: int, optimizer_idx: int = 0) -> torch.Tensor:
    x = self.get_input(batch)
    additional_decode_kwargs = {key: batch[key] for key in self.additional_decode_keys.intersection(batch)}
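    # Full autoencoder pass: encode x to the latent z, decode back to the
    # reconstruction xrec; regularization_log carries the regularizer's
    # statistics (e.g. a KL term).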
    z, xrec, regularization_log = self(x, **additional_decode_kwargs)
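    # A loss that declares `forward_keys` receives exactly the extras it lists;
    # a plain loss (no such attribute) is called with x and xrec only.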
    if hasattr(self.loss, "forward_keys"):
        extra_info = {
            "z": z,
            "optimizer_idx": optimizer_idx,
            "global_step": self.global_step,
            "last_layer": self.get_last_layer(),
            "split": "train",
            "regularization_log": regularization_log,
            "autoencoder": self,
        }
        extra_info = {k: extra_info[k] for k in self.loss.forward_keys}
    else:
        extra_info = dict()
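
    # Two-optimizer GAN setup: index 0 updates the autoencoder, index 1 the discriminator.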
    if optimizer_idx == 0:
        # autoencode
        out_loss = self.loss(x, xrec, **extra_info)
        if isinstance(out_loss, tuple):
            aeloss, log_dict_ae = out_loss
        else:
            # simple loss function
            aeloss = out_loss
            log_dict_ae = {"train/loss/rec": aeloss.detach()}
        self.log_dict(
            log_dict_ae,
            prog_bar=False,
            logger=True,
            on_step=True,
            on_epoch=True,
            sync_dist=False,
        )
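        # Progress-bar-only scalar; logger=False avoids logging the same loss twice.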
        self.log(
            "loss",
            aeloss.mean().detach(),
            prog_bar=True,
            logger=False,
            on_epoch=False,
            on_step=True,
        )
        return aeloss
    elif optimizer_idx == 1:
        # discriminator
        discloss, log_dict_disc = self.loss(x, xrec, **extra_info)
        # The discriminator loss is always expected to return a (loss, log_dict) tuple.
        self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
        return discloss
    else:
        raise NotImplementedError(f"Unknown optimizer {optimizer_idx}")
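
For reference, a minimal sketch of the `forward_keys` contract this method relies on. `ToyGANLoss` and its key list are hypothetical stand-ins, not the repository's actual loss class; the sketch only illustrates that a loss advertises the kwargs it wants and returns a (loss, log_dict) tuple.

import torch
import torch.nn as nn


class ToyGANLoss(nn.Module):  # hypothetical example, not the repo's loss class
    # inner_training_step filters its extra_info dict down to exactly these keys.
    forward_keys = ["optimizer_idx", "global_step", "split"]

    def forward(self, x, xrec, optimizer_idx=0, global_step=0, split="train"):
        if optimizer_idx == 0:
            # Autoencoder step: a plain L1 reconstruction term stands in for the real loss.
            loss = torch.abs(x - xrec).mean()
            log = {f"{split}/loss/rec": loss.detach()}
        else:
            # Discriminator step: placeholder only; a real loss would score real vs. fake
            # (e.g. a hinge or vanilla GAN term) instead of this constant.
            loss = torch.zeros((), device=x.device, requires_grad=True)
            log = {f"{split}/loss/disc": loss.detach()}
        # Either branch returns the (loss, log_dict) tuple that inner_training_step unpacks.
        return loss, log

With this loss, extra_info is filtered down to {"optimizer_idx", "global_step", "split"} before the call, so self.loss(x, xrec, **extra_info) matches the forward signature above.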