def inference()

in activemri/experimental/cvpr19_models/trainer.py [0:0]


    def inference(self, batch):
        """Reconstruct one batch without tracking gradients and collect outputs.

        The reconstructor is switched to eval mode (and so is the evaluator,
        when one is configured). Neither is switched back to train mode here,
        so callers that continue training must do that themselves.

        Args:
            batch: Raw batch, handed unchanged to
                ``models.fft_utils.preprocess_inputs`` together with
                ``self.options.dataroot`` and ``self.options.device``.

        Returns:
            dict with keys ``ground_truth``, ``zero_filled_image``,
            ``reconstructed_image`` (as produced by preprocessing / the
            reconstructor), their ``*_magnitude`` counterparts (via
            ``models.fft_utils.to_magnitude``), ``uncertainty_map``, ``mask``,
            and ``reconstructor_eval`` / ``ground_truth_eval`` (both ``None``
            when ``self.evaluator`` is ``None``).

            When ``self.options.dataroot == "KNEE_RAW"`` the three magnitude
            tensors and the uncertainty map are center-cropped to
            ``[320, 320]``; the non-magnitude images and the mask are not.
        """
        self.reconstructor.eval()

        with torch.no_grad():
            fft_utils = models.fft_utils

            zero_filled_image, ground_truth, mask = fft_utils.preprocess_inputs(
                batch, self.options.dataroot, self.options.device
            )

            # The reconstructor yields the image estimate, a per-pixel
            # uncertainty map, and an embedding of the sampling mask.
            reconstructed_image, uncertainty_map, mask_embedding = self.reconstructor(
                zero_filled_image, mask
            )

            reconstructor_eval, ground_truth_eval = None, None
            evaluator = self.evaluator
            if evaluator is not None:
                evaluator.eval()
                # Score both the reconstruction and the ground truth under the
                # same mask embedding.
                reconstructor_eval = evaluator(reconstructed_image, mask_embedding, mask)
                ground_truth_eval = evaluator(ground_truth, mask_embedding, mask)

            # Magnitude images are used for validation losses and plots.
            zero_filled_image_magnitude = fft_utils.to_magnitude(zero_filled_image)
            reconstructed_image_magnitude = fft_utils.to_magnitude(reconstructed_image)
            ground_truth_magnitude = fft_utils.to_magnitude(ground_truth)

            if self.options.dataroot == "KNEE_RAW":

                def _crop(tensor):
                    # Fresh size list per call, matching a literal at each call site.
                    return fft_utils.center_crop(tensor, [320, 320])

                reconstructed_image_magnitude = _crop(reconstructed_image_magnitude)
                ground_truth_magnitude = _crop(ground_truth_magnitude)
                zero_filled_image_magnitude = _crop(zero_filled_image_magnitude)
                uncertainty_map = _crop(uncertainty_map)

            return dict(
                ground_truth=ground_truth,
                zero_filled_image=zero_filled_image,
                reconstructed_image=reconstructed_image,
                ground_truth_magnitude=ground_truth_magnitude,
                zero_filled_image_magnitude=zero_filled_image_magnitude,
                reconstructed_image_magnitude=reconstructed_image_magnitude,
                uncertainty_map=uncertainty_map,
                mask=mask,
                reconstructor_eval=reconstructor_eval,
                ground_truth_eval=ground_truth_eval,
            )