def __call__()

in activemri/experimental/cvpr19_models/trainer.py [0:0]


    def __call__(self) -> float:
        """Build the models, train them, and return the best validation score.

        Steps, in order:

        1. Attach a ``logging.FileHandler`` writing to
           ``<options.checkpoints_dir>/trainer.log`` to the *root* logger,
           set the root level to DEBUG if ``options.debug`` else INFO, and
           log every option.
        2. Create the reconstructor and its Adam optimizer ``"G"``. The
           reconstructor is wrapped in ``DataParallel`` only when
           ``options.device`` is CUDA.
        3. If ``options.use_evaluator``, also create the evaluator (always
           wrapped in ``DataParallel``) and its Adam optimizer ``"D"``.
        4. Build the data loaders, then call
           ``load_from_checkpoint_if_present`` and
           ``load_weights_from_given_checkpoint``.
        5. Wire up ignite training/validation engines, validation metrics
           (``"mse"``, ``"ssim"`` and, with an evaluator, ``"loss_D"``),
           TensorBoard plots and checkpoint-saving handlers.
        6. Train for the epochs still missing
           (``options.max_epochs - self.completed_epochs``). Training is
           skipped when no epochs remain.

        The log file handler and the TensorBoard writer are released on
        exit, including when training raises. Repeated calls in one process
        therefore neither leak file descriptors nor duplicate log lines.

        Returns:
            ``self.best_validation_score`` after training. Presumably this is
            updated by ``run_validation_and_update_best_checkpoint`` (not
            visible here) — confirm.
        """
        # getLogger() with no name returns the root logger, so the level set
        # here applies process-wide to loggers without their own level.
        self.logger = logging.getLogger()
        self.logger.setLevel(logging.DEBUG if self.options.debug else logging.INFO)
        log_handler = logging.FileHandler(
            os.path.join(self.options.checkpoints_dir, "trainer.log")
        )
        log_handler.setFormatter(
            logging.Formatter(
                "%(asctime)s - %(threadName)s - %(levelname)s: %(message)s"
            )
        )
        self.logger.addHandler(log_handler)

        writer = None
        try:
            self.logger.info("Creating trainer with the following options:")
            for key, value in vars(self.options).items():
                if key == "device":
                    value = value.type
                elif key == "gpu_ids":
                    value = (
                        f"cuda : {value}" if torch.cuda.is_available() else "cpu"
                    )
                # Apply str() before the format spec: a non-empty spec makes
                # bools render as 1/0 and raises TypeError for lists/dicts.
                self.logger.info(f"    {key:>25}: {str(value):<30}")

            # Create Reconstructor Model
            self.reconstructor = models.reconstruction.ReconstructorNetwork(
                number_of_cascade_blocks=self.options.number_of_cascade_blocks,
                n_downsampling=self.options.n_downsampling,
                number_of_filters=self.options.number_of_reconstructor_filters,
                number_of_layers_residual_bottleneck=self.options.number_of_layers_residual_bottleneck,
                mask_embed_dim=self.options.mask_embed_dim,
                dropout_probability=self.options.dropout_probability,
                img_width=self.options.image_width,
                use_deconv=self.options.use_deconv,
            )

            if self.options.device.type == "cuda":
                self.reconstructor = torch.nn.DataParallel(self.reconstructor).to(
                    self.options.device
                )
            self.optimizers = {
                "G": optim.Adam(
                    self.reconstructor.parameters(),
                    lr=self.options.lr,
                    betas=(self.options.beta1, 0.999),
                )
            }

            # Create Evaluator Model
            if self.options.use_evaluator:
                self.evaluator = models.evaluator.EvaluatorNetwork(
                    number_of_filters=self.options.number_of_evaluator_filters,
                    number_of_conv_layers=self.options.number_of_evaluator_convolution_layers,
                    use_sigmoid=False,
                    width=self.options.image_width,
                    # NOTE(review): 640 is a hard-coded image height for the
                    # "KNEE_RAW" dataset. None presumably means "same as
                    # width" — confirm in EvaluatorNetwork.
                    height=640 if self.options.dataroot == "KNEE_RAW" else None,
                    mask_embed_dim=self.options.mask_embed_dim,
                )
                # Unlike the reconstructor, the evaluator is wrapped even on
                # CPU. Changing that alters its state_dict keys ("module."
                # prefix) and may break existing checkpoints, so it is kept.
                self.evaluator = torch.nn.DataParallel(self.evaluator).to(
                    self.options.device
                )

                self.optimizers["D"] = optim.Adam(
                    self.evaluator.parameters(),
                    lr=self.options.lr,
                    betas=(self.options.beta1, 0.999),
                )

            train_loader, val_loader = self.get_loaders()

            # Checkpoints are loaded after the models and optimizers above
            # exist, since loading presumably restores their state.
            self.load_from_checkpoint_if_present()
            self.load_weights_from_given_checkpoint()

            writer = SummaryWriter(self.options.checkpoints_dir)

            # Training engine and handlers
            train_engine = Engine(lambda engine, batch: self.update(batch))
            val_engine = Engine(lambda engine, batch: self.inference(batch))

            def magnitude_pair(output):
                # (prediction, target) magnitudes consumed by MSE and SSIM.
                return (
                    output["reconstructed_image_magnitude"],
                    output["ground_truth_magnitude"],
                )

            validation_mse = Loss(loss_fn=F.mse_loss, output_transform=magnitude_pair)
            validation_mse.attach(val_engine, name="mse")

            validation_ssim = Loss(
                loss_fn=util.common.compute_ssims, output_transform=magnitude_pair
            )
            validation_ssim.attach(val_engine, name="ssim")

            if self.options.use_evaluator:
                # ignite's Loss passes a third tuple element (a dict) to
                # loss_fn as keyword arguments.
                validation_loss_d = Loss(
                    loss_fn=self.discriminator_loss,
                    output_transform=lambda x: (
                        x["reconstructor_eval"],
                        x["ground_truth_eval"],
                        {
                            "reconstructed_image": x["reconstructed_image"],
                            "target": x["ground_truth"],
                            "mask": x["mask"],
                        },
                    ),
                )
                validation_loss_d.attach(val_engine, name="loss_D")

            progress_bar = ProgressBar()
            progress_bar.attach(train_engine)

            # Registered before the plotting handlers so that validation
            # metrics/outputs exist when those EPOCH_COMPLETED handlers run.
            train_engine.add_event_handler(
                Events.EPOCH_COMPLETED,
                run_validation_and_update_best_checkpoint,
                val_engine=val_engine,
                progress_bar=progress_bar,
                val_loader=val_loader,
                trainer=self,
            )

            # Tensorboard Plots
            @train_engine.on(Events.ITERATION_COMPLETED)
            def plot_training_loss(engine):
                writer.add_scalar(
                    "training/generator_loss",
                    engine.state.output["loss_G"],
                    self.updates_performed,
                )
                if "loss_D" in engine.state.output:
                    writer.add_scalar(
                        "training/discriminator_loss",
                        engine.state.output["loss_D"],
                        self.updates_performed,
                    )

            @train_engine.on(Events.EPOCH_COMPLETED)
            def plot_validation_loss(_):
                metrics = val_engine.state.metrics
                writer.add_scalar("validation/MSE", metrics["mse"], self.completed_epochs)
                writer.add_scalar(
                    "validation/SSIM", metrics["ssim"], self.completed_epochs
                )
                if "loss_D" in metrics:
                    writer.add_scalar(
                        "validation/loss_D", metrics["loss_D"], self.completed_epochs
                    )

            @train_engine.on(Events.EPOCH_COMPLETED)
            def plot_validation_images(_):
                # Only the last validation batch's output is plotted.
                output = val_engine.state.output
                ground_truth = output["ground_truth_magnitude"]
                zero_filled_image = output["zero_filled_image_magnitude"]
                reconstructed_image = output["reconstructed_image_magnitude"]
                uncertainty_map = output["uncertainty_map"]
                difference = torch.abs(ground_truth - reconstructed_image)

                # Create plots
                writer.add_image(
                    "validation_images/ground_truth",
                    util.common.create_grid_from_tensor(ground_truth),
                    self.completed_epochs,
                )
                writer.add_image(
                    "validation_images/zero_filled_image",
                    util.common.create_grid_from_tensor(zero_filled_image),
                    self.completed_epochs,
                )
                writer.add_image(
                    "validation_images/reconstructed_image",
                    util.common.create_grid_from_tensor(reconstructed_image),
                    self.completed_epochs,
                )

                # .exp() suggests the map is stored in log space — confirm.
                writer.add_image(
                    "validation_images/uncertainty_map",
                    util.common.gray2heatmap(
                        util.common.create_grid_from_tensor(uncertainty_map.exp()),
                        cmap="jet",
                    ),
                    self.completed_epochs,
                )

                writer.add_image(
                    "validation_images/difference",
                    util.common.gray2heatmap(
                        util.common.create_grid_from_tensor(difference), cmap="gray"
                    ),
                    self.completed_epochs,
                )

                # The mask is 4-D (shape[3] is indexed). It is tiled along
                # dim 2, shape[3] times, so it renders as a 2-D image;
                # presumably (B, 1, 1, W) -> (B, 1, W, W) — confirm.
                mask = output["mask"]
                writer.add_image(
                    "validation_images/mask_image",
                    util.common.create_grid_from_tensor(
                        mask.repeat(1, 1, mask.shape[3], 1)
                    ),
                    self.completed_epochs,
                )

            train_engine.add_event_handler(
                Events.EPOCH_COMPLETED,
                save_regular_checkpoint,
                trainer=self,
                progress_bar=progress_bar,
            )

            # When resuming an already-finished run nothing is left to train.
            # Newer ignite versions reject a non-positive max_epochs.
            remaining_epochs = self.options.max_epochs - self.completed_epochs
            if remaining_epochs > 0:
                train_engine.run(train_loader, remaining_epochs)
            else:
                self.logger.info(
                    "No epochs left to train (completed %s of %s); skipping.",
                    self.completed_epochs,
                    self.options.max_epochs,
                )

            return self.best_validation_score
        finally:
            # Release per-call resources even if training raised, so repeated
            # calls don't stack FileHandlers on the root logger.
            if writer is not None:
                writer.close()
            self.logger.removeHandler(log_handler)
            log_handler.close()