def update()

in activemri/experimental/cvpr19_models/trainer.py [0:0]


    def update(self, batch):
        """Run one training step on ``batch`` and return the scalar losses.

        Steps:
          1. Preprocess the batch with ``models.fft_utils.preprocess_inputs``
             into ``(zero_filled_image, target, mask)`` and run the
             reconstructor on it.
          2. If ``self.evaluator`` is set, take one step of optimizer ``"D"``
             on the evaluator. It is fed the detached reconstruction as the
             "fake" sample and ``target`` as the "real" sample. Unless
             ``options.only_evaluator`` is set, also compute the
             generator-side GAN loss, scaled by ``options.lambda_gan``.
          3. Unless ``options.only_evaluator`` is set, take one step of
             optimizer ``"G"``. The loss is the mean of the ``"NLL"`` loss
             plus the GAN term from step 2.
          4. Increment ``self.updates_performed``.

        Args:
            batch: a data batch in whatever format
                ``models.fft_utils.preprocess_inputs`` accepts (not visible
                here).

        Returns:
            dict with float entries ``"loss_D"`` (0.0 when there is no
            evaluator) and ``"loss_G"`` (0.0 when ``only_evaluator`` is set).
        """
        # Only switch the reconstructor to train mode when it is actually
        # being optimized; in only_evaluator mode its current mode is kept.
        if not self.options.only_evaluator:
            self.reconstructor.train()

        (zero_filled_image, target, mask,) = models.fft_utils.preprocess_inputs(
            batch, self.options.dataroot, self.options.device
        )

        # Get reconstructor output. This runs even in only_evaluator mode,
        # because the evaluator is trained on the reconstruction.
        reconstructed_image, uncertainty_map, mask_embedding = self.reconstructor(
            zero_filled_image, mask
        )

        # ------------------------------------------------------------------------
        # Update evaluator and compute generator GAN Loss
        # ------------------------------------------------------------------------
        # Defaults used when there is no evaluator: the int 0 is added to
        # loss_G below, and loss_D is a tensor so that .item() works on return.
        loss_G_GAN = 0
        loss_D = torch.tensor(0.0)
        if self.evaluator is not None:
            self.evaluator.train()
            self.optimizers["D"].zero_grad()
            fake = reconstructed_image
            # Detached copy so the evaluator update does not backprop into the
            # reconstructor through the image.
            detached_fake = fake.detach()
            # NOTE(review): the detached mask_embedding is also reused for the
            # generator-side forward below, so the G GAN loss never sends
            # gradient through the mask embedding. Presumably intentional;
            # confirm.
            if self.options.mask_embed_dim != 0:
                mask_embedding = mask_embedding.detach()
            output = self.evaluator(
                detached_fake,
                mask_embedding,
                mask if self.options.add_mask_eval else None,
            )
            # Fake sample is labelled False. The meaning of `degree` is
            # defined by the GAN loss (not visible here); 0 for fake and 1 for
            # real is the convention used in this method.
            loss_D_fake = self.losses["GAN"](
                output, False, mask, degree=0, pred_and_gt=(detached_fake, target)
            )

            real = target
            output = self.evaluator(
                real, mask_embedding, mask if self.options.add_mask_eval else None
            )
            loss_D_real = self.losses["GAN"](
                output, True, mask, degree=1, pred_and_gt=(detached_fake, target)
            )

            loss_D = loss_D_fake + loss_D_real
            # retain_graph=True: presumably needed because, when
            # mask_embed_dim == 0, mask_embedding is not detached above and
            # may share graph nodes with the reconstructor, which
            # loss_G.backward() traverses again below. TODO confirm before
            # removing.
            loss_D.backward(retain_graph=True)
            # The evaluator is stepped *before* the generator-side forward
            # below, so the G GAN loss is measured against the updated
            # evaluator.
            self.optimizers["D"].step()

            if not self.options.only_evaluator:
                # Non-detached `fake` here, so this loss's gradient reaches
                # the reconstructor.
                output = self.evaluator(
                    fake, mask_embedding, mask if self.options.add_mask_eval else None
                )
                # Generator-side loss uses the "real" label (True) for the
                # reconstruction.
                loss_G_GAN = self.losses["GAN"](
                    output,
                    True,
                    mask,
                    degree=1,
                    updateG=True,
                    pred_and_gt=(fake, target),
                )
                loss_G_GAN *= self.options.lambda_gan

        # ------------------------------------------------------------------------
        # Update reconstructor
        # ------------------------------------------------------------------------
        loss_G = torch.tensor(0.0)
        if not self.options.only_evaluator:
            # Clearing G grads here (after loss_D.backward()) also discards any
            # reconstructor gradients the evaluator loss may have produced.
            self.optimizers["G"].zero_grad()
            loss_G = self.losses["NLL"](
                reconstructed_image, target, uncertainty_map, self.options
            ).mean()
            # loss_G_GAN is still the int 0 when there is no evaluator.
            loss_G += loss_G_GAN
            # NOTE(review): this backward also accumulates evaluator grads via
            # the G GAN path. Those are cleared by optimizers["D"].zero_grad()
            # on the next update, assuming "D" holds the evaluator params.
            loss_G.backward()
            self.optimizers["G"].step()

        self.updates_performed += 1

        return {"loss_D": loss_D.item(), "loss_G": loss_G.item()}