def preprocess_inputs()

in activemri/experimental/cvpr19_models/models/fft_utils.py [0:0]


def preprocess_inputs(batch, dataroot, device, prev_reconstruction=None):
    """Build the reconstructor input from the masked (observed) k-space.

    Entries selected by the mask keep their true k-space values. Unselected
    entries are filled with zeros, or with the FFT of ``prev_reconstruction``
    when one is given. The result is transformed back with ``ifft``.

    Args:
        batch: Indexable batch. ``batch[0]`` is the sampling mask and
            ``batch[1]`` the target. When ``dataroot == "KNEE_RAW"``,
            ``batch[2]`` holds the raw k-space (presumably channels-last,
            since it is permuted with ``(0, 3, 1, 2)`` — TODO confirm).
        dataroot: Dataset identifier. ``"KNEE_RAW"`` selects the raw-k-space
            path. Any other value derives k-space as ``fft(target)``.
        device: Device the batch tensors are moved to. ``prev_reconstruction``
            is not moved and is presumably already on this device.
        prev_reconstruction: Optional previous reconstruction. When given, its
            FFT fills the k-space entries the mask does not select.

    Returns:
        Tuple ``(reconstructor_input, target, mask)``. On the ``"KNEE_RAW"``
        path, ``target`` is permuted with ``(0, 3, 1, 2)``, and ``mask`` is
        forced to 1 wherever the raw k-space magnitude summed over dim 2 is
        zero.
    """
    mask = batch[0].to(device)
    target = batch[1].to(device)

    if dataroot == "KNEE_RAW":
        k_space = batch[2].permute(0, 3, 1, 2).to(device)
        # Alter the mask to always include the highest frequencies that contain
        # padding. These are detected as positions whose magnitude, summed over
        # dim 2, is exactly zero.
        is_padding = to_magnitude(k_space).sum(2).unsqueeze(2) == 0.0
        mask = torch.where(is_padding, torch.tensor(1.0, device=device), mask)

        if prev_reconstruction is None:
            fill = torch.tensor(0.0, device=device)
        else:
            # NOTE(review): 160 rows / 24 columns look tied to the knee raw
            # image size vs. its cropped region. Confirm before reusing these
            # values with other data.
            border_rows = 160
            border_cols = 24
            # Clone so that zeroing the border does not mutate the caller's tensor.
            prev_reconstruction = prev_reconstruction.clone()
            prev_reconstruction[:, :, :border_rows, :] = 0
            prev_reconstruction[:, :, -border_rows:, :] = 0
            prev_reconstruction[:, :, :, :border_cols] = 0
            prev_reconstruction[:, :, :, -border_cols:] = 0
            fill = fft(prev_reconstruction, shift=True)

        # torch.where requires a bool condition. uint8 (``.byte()``) conditions
        # are deprecated and rejected by recent PyTorch. Assumes the mask is
        # binary {0, 1}.
        masked_true_k_space = torch.where(mask.bool(), k_space, fill)
        reconstructor_input = ifft(masked_true_k_space, ifft_shift=True)
        target = target.permute(0, 3, 1, 2)
    else:
        fft_target = fft(target)
        if prev_reconstruction is None:
            fill = torch.tensor(0.0, device=device)
        else:
            fill = fft(prev_reconstruction)

        # Bool condition for torch.where (see note in the KNEE_RAW branch).
        masked_true_k_space = torch.where(mask.bool(), fft_target, fill)
        reconstructor_input = ifft(masked_true_k_space)

    return reconstructor_input, target, mask