in activemri/experimental/cvpr19_models/models/reconstruction.py [0:0]
def forward(self, zero_filled_input, mask):
    """Run the cascade of encoder / bottleneck / decoder blocks on the input.

    Each cascade block encodes its input, passes the result through a
    residual bottleneck, and decodes it. The last channel of the decoder
    output is taken as the uncertainty map; the remaining channels go
    through ``self.data_consistency`` together with the original
    zero-filled input and the mask, and the result feeds the next block.

    Args:
        zero_filled_input(torch.Tensor): Image obtained from zero-filled
            reconstruction of partial k-space scans. Dimensions 2 and 3 are
            used as the spatial size when tiling the mask embedding.
        mask(torch.Tensor): Mask used in creating the zero-filled image from
            the ground truth image.

    Returns:
        tuple(torch.Tensor, torch.Tensor, torch.Tensor): Values produced by
        the last cascade block, in this order:

        * Reconstructed image (output of the data consistency step).
        * Uncertainty map (last channel of the decoder output).
        * Mask embedding, tiled to the spatial size of the input, or
          ``None`` when ``self.use_mask_embedding`` is false.

    Note:
        The three module lists are iterated with ``zip``, so iteration stops
        at the shortest one. At least one cascade block is required,
        otherwise the returned names are never assigned.
    """
    if self.use_mask_embedding:
        height = zero_filled_input.shape[2]
        width = zero_filled_input.shape[3]
        # Tile the embedding over the spatial dimensions so it can be
        # concatenated with the image along the channel axis.
        mask_embedding = self.embed_mask(mask).repeat(1, 1, height, width)
        encoder_input = torch.cat((zero_filled_input, mask_embedding), dim=1)
    else:
        mask_embedding = None
        encoder_input = zero_filled_input

    cascade_stages = zip(
        self.encoders_all_cascade_blocks,
        self.residual_bottlenecks_all_cascade_blocks,
        self.decoders_all_cascade_blocks,
    )

    previous_bottleneck = None
    for stage_index, stage in enumerate(cascade_stages):
        encode, bottleneck, decode = stage

        features = encode(encoder_input)
        if stage_index != 0:
            # Skip connection from the previous block's residual bottleneck.
            features = features + previous_bottleneck
        previous_bottleneck = bottleneck(features)
        decoded = decode(previous_bottleneck)

        # All channels but the last form the image estimate; the last one
        # is the uncertainty map.
        image_estimate = decoded[:, :-1, ...]
        reconstructed_image = self.data_consistency(
            image_estimate, zero_filled_input, mask
        )
        uncertainty_map = decoded[:, -1:, :, :]

        # The reconstruction (plus the embedding, if enabled) is the input
        # of the next cascade block.
        if self.use_mask_embedding:
            encoder_input = torch.cat((reconstructed_image, mask_embedding), dim=1)
        else:
            encoder_input = reconstructed_image

    return reconstructed_image, uncertainty_map, mask_embedding