in activemri/experimental/cvpr19_models/models/fft_utils.py [0:0]
def preprocess_inputs(batch, dataroot, device, prev_reconstruction=None):
    """Build the reconstructor input by masking k-space and inverse-transforming it.

    Observed k-space entries (where ``mask`` is nonzero) are taken from the
    ground truth. Unobserved entries are filled with zeros or, when
    ``prev_reconstruction`` is given, with the FFT of that previous
    reconstruction. The result is passed through ``ifft`` to produce the
    reconstructor input.

    Args:
        batch: Indexable batch. ``batch[0]`` is the mask and ``batch[1]`` is the
            target. For ``dataroot == "KNEE_RAW"``, ``batch[2]`` holds raw
            k-space (permuted with ``(0, 3, 1, 2)`` here; presumably
            channels-last -> channels-first, TODO confirm).
        dataroot: Dataset identifier. ``"KNEE_RAW"`` selects the raw-k-space
            path; any other value derives k-space as ``fft(target)``.
        device: Device that mask, target and k-space are moved to.
        prev_reconstruction: Optional previous reconstruction whose FFT fills
            the unobserved k-space entries. If ``None``, those entries are zero.
            It is not moved to ``device``, so the caller must place it there.

    Returns:
        Tuple ``(reconstructor_input, target, mask)``. For ``"KNEE_RAW"`` the
        returned ``target`` is permuted with ``(0, 3, 1, 2)``, and the returned
        ``mask`` is the padding-augmented mask built below.
    """
    mask = batch[0].to(device)
    target = batch[1].to(device)
    if dataroot == "KNEE_RAW":
        k_space = batch[2].permute(0, 3, 1, 2).to(device)
        # Alter the mask so it always includes the highest frequencies that
        # contain padding. Positions whose magnitude sums to exactly zero over
        # dim 2 are forced to 1.
        # NOTE(review): assumes to_magnitude keeps a leading (batch, channel)
        # layout so the unsqueezed sum broadcasts against mask. Confirm.
        is_padding = to_magnitude(k_space).sum(2).unsqueeze(2) == 0.0
        mask = torch.where(is_padding, torch.tensor(1.0, device=device), mask)

        if prev_reconstruction is None:
            fill = torch.tensor(0.0, device=device)
        else:
            # Clone so the caller's tensor is not modified by the in-place zeroing.
            cropped = prev_reconstruction.clone()
            # Zero fixed 160-row / 24-column borders before the FFT.
            # NOTE(review): these margins presumably match the padded region
            # of the raw knee images. Confirm against the dataset's image size.
            cropped[:, :, :160, :] = 0
            cropped[:, :, -160:, :] = 0
            cropped[:, :, :, :24] = 0
            cropped[:, :, :, -24:] = 0
            fill = fft(cropped, shift=True)

        # torch.where requires a boolean condition; uint8 (.byte()) conditions
        # are deprecated/rejected by current PyTorch. Assumes a binary 0/1 mask.
        masked_true_k_space = torch.where(mask.bool(), k_space, fill)
        reconstructor_input = ifft(masked_true_k_space, ifft_shift=True)
        target = target.permute(0, 3, 1, 2)
    else:
        fft_target = fft(target)
        if prev_reconstruction is None:
            fill = torch.tensor(0.0, device=device)
        else:
            fill = fft(prev_reconstruction)
        # Boolean condition required by torch.where (see note above).
        masked_true_k_space = torch.where(mask.bool(), fft_target, fill)
        reconstructor_input = ifft(masked_true_k_space)
    return reconstructor_input, target, mask