in activemri/experimental/cvpr19_models/models/evaluator.py [0:0]
def forward(self, reconstructed_image, mask_embedding, mask):
    """Split an image into one spectral map per k-space column.

    The image is moved to k-space with ``fft_utils.fft``. For every column
    ``i``, a copy of that k-space is made in which all columns except ``i``
    are zeroed, and the copy is brought back with ``fft_utils.ifft``. The
    resulting two-channel maps are stacked along dim 1 in the order
    ``[real_M0, img_M0, real_M1, img_M1, ...]``.

    Args:
        reconstructed_image: image tensor; dims 0, 2 and 3 are read as
            batch, height and width. The k-space is reshaped to
            ``(..., 2, height, width)``, so a ``(batch, 2, height, width)``
            real/imag layout is assumed -- TODO confirm against fft_utils.
        mask_embedding: optional 4-D tensor concatenated after the spectral
            maps along dim 1 (batch, height and width must match).
        mask: optional 4-D tensor. Its last dim is treated as the per-column
            value and added (detached) to both channels of the matching
            spectral map. Presumably ``(batch, 1, 1, width)`` -- verify
            against callers.

    Returns:
        Tensor of shape ``(batch, 2 * width + C_emb, height, width)``, where
        ``C_emb`` is the channel count of ``mask_embedding`` (0 if None).
    """
    batch_size = reconstructed_image.shape[0]
    height = reconstructed_image.shape[2]
    width = reconstructed_image.shape[3]

    # k-space with a singleton "map" axis: (batch, 1, 2, height, width).
    kspace = fft_utils.fft(reconstructed_image).unsqueeze(1)

    # Boolean selector of shape (1, width, 1, 1, width): copy i keeps only
    # k-space column i. torch.where expects a bool condition (uint8
    # conditions are deprecated). Broadcasting against the singleton map
    # axis replaces an explicit repeat of the k-space.
    column_selector = (
        torch.eye(width, device=kspace.device).bool().view(1, width, 1, 1, width)
    )
    # new_zeros(()) matches kspace's dtype/device, so no float32 literal
    # is mixed with a possibly different k-space dtype.
    masked_kspace = torch.where(column_selector, kspace, kspace.new_zeros(()))
    masked_kspace = masked_kspace.reshape(batch_size * width, 2, height, width)

    # Back to image space. Restore the (batch, map, real/imag, H, W) layout
    # that the flattening above produced.
    separate_images = fft_utils.ifft(masked_kspace).reshape(
        batch_size, width, 2, height, width
    )

    # Add mask information as a summation -- might not be optimal.
    if mask is not None:
        # (batch, d1, d2, d3) -> (batch, d3, 1, d1, d2): mask column i is
        # broadcast onto both channels of spectral map i.
        # NOTE: the map axis must be dim 1 here. Viewing the tensor as
        # (batch, 2, width, ...) instead attaches mask column k to flattened
        # channels k and width + k, i.e. to unrelated maps. Checkpoints
        # trained with that layout will see different inputs.
        per_map_mask = mask.permute(0, 3, 1, 2).unsqueeze(2).detach()
        separate_images = separate_images + per_map_mask

    # Always flatten to 4-D so the output rank does not depend on `mask`.
    # Result: (batch, [real_M0, img_M0, real_M1, img_M1, ...], height, width).
    separate_images = separate_images.reshape(
        batch_size, 2 * width, height, width
    )

    # Concatenate mask embedding.
    if mask_embedding is not None:
        return torch.cat([separate_images, mask_embedding], dim=1)
    return separate_images