in activemri/experimental/cvpr19_models/trainer.py [0:0]
def update(self, batch):
    """Run one training step for the evaluator (D) and/or the reconstructor (G).

    The batch is preprocessed into ``(zero_filled_image, target, mask)`` and fed
    through the reconstructor. Then:

    * If ``self.evaluator`` is set, D is updated with the GAN loss on the
      detached reconstruction ("fake") plus the GAN loss on ``target``
      ("real"). Unless ``options.only_evaluator`` is set, the generator GAN
      loss is then computed on the non-detached reconstruction and scaled by
      ``options.lambda_gan``.
    * Unless ``options.only_evaluator`` is set, G is updated with the mean NLL
      loss plus that (possibly zero) generator GAN loss.

    Args:
        batch: Forwarded to ``models.fft_utils.preprocess_inputs`` together
            with ``options.dataroot`` and ``options.device``.

    Returns:
        dict: ``{"loss_D": float, "loss_G": float}``. ``loss_D`` is ``0.0``
        when there is no evaluator; ``loss_G`` is ``0.0`` when
        ``options.only_evaluator`` is set.

    Side effects:
        May call ``self.optimizers["D"].step()`` and/or
        ``self.optimizers["G"].step()``, and always increments
        ``self.updates_performed``.
    """
    if not self.options.only_evaluator:
        # Training mode only when the reconstructor is actually updated below.
        self.reconstructor.train()
    (zero_filled_image, target, mask,) = models.fft_utils.preprocess_inputs(
        batch, self.options.dataroot, self.options.device
    )
    # Get reconstructor output
    reconstructed_image, uncertainty_map, mask_embedding = self.reconstructor(
        zero_filled_image, mask
    )
    # ------------------------------------------------------------------------
    # Update evaluator and compute generator GAN Loss
    # ------------------------------------------------------------------------
    # Defaults when there is no evaluator: no adversarial term for G, and a
    # tensor-valued zero D loss so that ``loss_D.item()`` at the end works.
    loss_G_GAN = 0
    loss_D = torch.tensor(0.0)
    if self.evaluator is not None:
        self.evaluator.train()
        self.optimizers["D"].zero_grad()
        fake = reconstructed_image
        # The D update uses a detached copy so its loss does not backpropagate
        # into the reconstructor.
        detached_fake = fake.detach()
        if self.options.mask_embed_dim != 0:
            # NOTE(review): this rebinds ``mask_embedding`` itself, so the
            # generator GAN pass further down also receives the detached
            # embedding (no adversarial gradient reaches G through it).
            # Presumably intentional — confirm.
            mask_embedding = mask_embedding.detach()
        output = self.evaluator(
            detached_fake,
            mask_embedding,
            mask if self.options.add_mask_eval else None,
        )
        loss_D_fake = self.losses["GAN"](
            output, False, mask, degree=0, pred_and_gt=(detached_fake, target)
        )
        real = target
        output = self.evaluator(
            real, mask_embedding, mask if self.options.add_mask_eval else None
        )
        # Note: pred_and_gt still passes the detached reconstruction here,
        # not ``real``.
        loss_D_real = self.losses["GAN"](
            output, True, mask, degree=1, pred_and_gt=(detached_fake, target)
        )
        loss_D = loss_D_fake + loss_D_real
        # NOTE(review): retain_graph=True keeps this graph alive for the later
        # loss_G.backward(). Presumably only needed if some input above is not
        # detached (e.g. mask_embedding when mask_embed_dim == 0) — confirm.
        loss_D.backward(retain_graph=True)
        self.optimizers["D"].step()
        if not self.options.only_evaluator:
            # Generator adversarial loss: a fresh evaluator pass on the
            # non-detached reconstruction, run after the D optimizer step, so
            # that gradients can reach the reconstructor via ``fake``.
            output = self.evaluator(
                fake, mask_embedding, mask if self.options.add_mask_eval else None
            )
            loss_G_GAN = self.losses["GAN"](
                output,
                True,
                mask,
                degree=1,
                updateG=True,
                pred_and_gt=(fake, target),
            )
            loss_G_GAN *= self.options.lambda_gan
    # ------------------------------------------------------------------------
    # Update reconstructor
    # ------------------------------------------------------------------------
    loss_G = torch.tensor(0.0)
    if not self.options.only_evaluator:
        self.optimizers["G"].zero_grad()
        loss_G = self.losses["NLL"](
            reconstructed_image, target, uncertainty_map, self.options
        ).mean()
        loss_G += loss_G_GAN
        # NOTE(review): when an evaluator is present, this backward presumably
        # also leaves gradients on the evaluator's parameters (from the
        # generator GAN pass); they are only cleared by
        # optimizers["D"].zero_grad() in the next update — confirm nothing
        # reads them in between.
        loss_G.backward()
        self.optimizers["G"].step()
    self.updates_performed += 1
    return {"loss_D": loss_D.item(), "loss_G": loss_G.item()}