def forward()

in activemri/experimental/cvpr19_models/models/evaluator.py [0:0]


    def forward(self, reconstructed_image, mask_embedding, mask):
        """Decompose an image into one spectral map per k-space column.

        The image is taken to k-space with ``fft_utils.fft``. For every column
        index ``i`` in ``range(width)`` a copy is made in which every k-space
        column other than ``i`` is zeroed. Each copy is brought back to image
        space with ``fft_utils.ifft``. The resulting maps are stacked along the
        channel dimension, optionally offset by ``mask``, and optionally
        concatenated with ``mask_embedding``.

        Args:
            reconstructed_image: 4-D tensor. Dims 0, 2 and 3 are read as batch,
                height and width. Dim 1 is presumably the 2 real/imaginary
                channels (the k-space is reshaped to ``(..., 2, height,
                width)``) -- TODO confirm.
            mask_embedding: optional tensor concatenated to the output along
                dim 1. Its other dims must match the output. ``None`` skips it.
            mask: optional 4-D tensor (``permute(0, 3, 1, 2)`` is applied),
                presumably ``(batch, 1, 1, width)`` -- TODO confirm. When
                given, it is detached and added to the spectral maps.
                ``None`` skips it.

        Returns:
            Tensor of shape ``(batch, 2 * width, height, width)``, with the
            channels of ``mask_embedding`` appended when it is provided.
        """
        batch_size = reconstructed_image.shape[0]
        height = reconstructed_image.shape[2]
        width = reconstructed_image.shape[3]

        # Create spectral maps in k-space. Shape is (batch, 1, 2, height, width);
        # the map dimension (size `width`) comes from broadcasting in
        # torch.where below, so no explicit repeat/copy of k-space is needed.
        kspace = fft_utils.fft(reconstructed_image).unsqueeze(1)

        # separate_mask[0, i, 0, 0, j] is True iff i == j: map i keeps only
        # k-space column i. Built directly on the target device as a bool
        # tensor, because torch.where deprecates/rejects uint8 conditions.
        separate_mask = (
            torch.eye(width, device=reconstructed_image.device)
            .bool()
            .view(1, width, 1, 1, width)
        )

        # Zero everything outside each map's column. The zero matches kspace's
        # dtype/device so no implicit cast or host->device copy is involved.
        zero = torch.zeros((), dtype=kspace.dtype, device=kspace.device)
        masked_kspace = torch.where(separate_mask, kspace, zero)
        # reshape (not view): the broadcast output's layout is chosen by torch.
        masked_kspace = masked_kspace.reshape(batch_size * width, 2, height, width)

        # Convert spectral maps to image space.
        separate_images = fft_utils.ifft(masked_kspace)
        # result is (batch, [real_M0, img_M0, real_M1, img_M1, ...], height, width)
        # NOTE(review): assuming ifft keeps the (batch*width, 2, h, w) layout,
        # memory order here is (batch, map, re/im, h, w). Viewing it as
        # (batch, 2, width, h, w) therefore does not align dim 2 with the map
        # index, and the mask offset below is not applied per map as its
        # broadcast suggests. Left unchanged because altering it changes the
        # features an already-trained evaluator expects; the per-map variant
        # would be view(batch_size, width, 2, height, width) with
        # mask.permute(0, 3, 1, 2).unsqueeze(2). Confirm before switching.
        separate_images = separate_images.contiguous().view(
            batch_size, 2, width, height, width
        )

        # Add mask information as a summation -- might not be optimal.
        # detach() keeps gradients from flowing into `mask` through this path.
        if mask is not None:
            mask_offset = mask.permute(0, 3, 1, 2).unsqueeze(1).detach()
            separate_images = separate_images + mask_offset

        spectral_map = separate_images.contiguous().view(
            batch_size, 2 * width, height, width
        )
        # Concatenate the mask embedding along the channel dimension.
        if mask_embedding is not None:
            spectral_map = torch.cat([spectral_map, mask_embedding], dim=1)

        return spectral_map