in activemri/experimental/cvpr19_models/trainer.py [0:0]
def inference(self, batch):
    """Run the reconstructor (and the evaluator, if configured) on one batch.

    The reconstructor is put in eval mode before the forward pass, and the
    evaluator too when it exists. Neither is switched back to train mode
    here. All computation runs under ``torch.no_grad()``.

    Args:
        batch: Raw batch, forwarded unchanged to
            ``models.fft_utils.preprocess_inputs`` together with
            ``self.options.dataroot`` and ``self.options.device``.

    Returns:
        dict: Contains the keys ``"ground_truth"``, ``"zero_filled_image"``,
        ``"reconstructed_image"``, ``"ground_truth_magnitude"``,
        ``"zero_filled_image_magnitude"``,
        ``"reconstructed_image_magnitude"``, ``"uncertainty_map"``,
        ``"mask"``, ``"reconstructor_eval"`` and ``"ground_truth_eval"``.
        The two ``*_eval`` entries are ``None`` when ``self.evaluator`` is
        ``None``. When ``self.options.dataroot == "KNEE_RAW"``, the three
        magnitude images and the uncertainty map are center-cropped to
        320x320; the non-magnitude images are returned uncropped.
    """
    self.reconstructor.eval()
    with torch.no_grad():
        zero_filled, target, mask = models.fft_utils.preprocess_inputs(
            batch, self.options.dataroot, self.options.device
        )

        # Forward pass through the reconstruction network.
        reconstruction, uncertainty, mask_embed = self.reconstructor(
            zero_filled, mask
        )

        # Optionally score both the reconstruction and the ground truth
        # with the evaluator, using the reconstructor's mask embedding.
        if self.evaluator is None:
            recon_score = None
            target_score = None
        else:
            self.evaluator.eval()
            recon_score = self.evaluator(reconstruction, mask_embed, mask)
            target_score = self.evaluator(target, mask_embed, mask)

        # Magnitude images feed the validation losses and plots.
        to_magnitude = models.fft_utils.to_magnitude
        zero_filled_mag = to_magnitude(zero_filled)
        reconstruction_mag = to_magnitude(reconstruction)
        target_mag = to_magnitude(target)

        if self.options.dataroot == "KNEE_RAW":
            # Crop data for this dataset to a fixed 320x320 window.
            def crop(image):
                return models.fft_utils.center_crop(image, [320, 320])

            reconstruction_mag = crop(reconstruction_mag)
            target_mag = crop(target_mag)
            zero_filled_mag = crop(zero_filled_mag)
            uncertainty = crop(uncertainty)

    outputs = {
        "ground_truth": target,
        "zero_filled_image": zero_filled,
        "reconstructed_image": reconstruction,
        "ground_truth_magnitude": target_mag,
        "zero_filled_image_magnitude": zero_filled_mag,
        "reconstructed_image_magnitude": reconstruction_mag,
        "uncertainty_map": uncertainty,
        "mask": mask,
        "reconstructor_eval": recon_score,
        "ground_truth_eval": target_score,
    }
    return outputs