in activemri/experimental/cvpr19_models/models/fft_utils.py [0:0]
def get_target_tensor(self, input, target_is_real, degree, mask, pred_and_gt=None):
    """Build the target tensor that the discriminator output is compared against.

    Args:
        input: Discriminator output. The returned target is created with
            ``torch.ones_like`` / ``torch.zeros_like`` from it, so it has the
            same shape, dtype and device.
        target_is_real: If true, every entry of the target is set to
            ``degree`` and nothing else is applied.
        degree: Label value written into the target. On the fake path without
            ``self.use_mse_as_energy`` it is only written when it differs
            from 1; otherwise the target stays at zero.
        mask: Indexed as ``mask[i, 0, 0, :]``; positions with non-zero entries
            get their target forced to 1 on the fake path.
            (presumably the k-space acquisition mask, shape (B, 1, 1, W) --
            TODO confirm against callers)
        pred_and_gt: ``(pred, gt)`` pair. Only read when
            ``self.use_mse_as_energy`` is set and ``target_is_real`` is false.

    Returns:
        A new tensor shaped like ``input`` holding the targets; ``input``
        itself is not modified.
    """
    if target_is_real:
        target_tensor = torch.ones_like(input)
        target_tensor[:] = degree
        return target_tensor

    target_tensor = torch.zeros_like(input)
    if self.use_mse_as_energy:
        pred, gt = pred_and_gt
        if self.options.dataroot == "KNEE_RAW":
            gt = center_crop(gt, [368, 320])
            pred = center_crop(pred, [368, 320])
        w = gt.shape[2]
        ks_gt = fft(gt, normalized=True)
        ks_input = fft(pred, normalized=True)
        # ``reduction="none"`` is the supported spelling of the deprecated
        # ``reduce=False``: keep the element-wise squared error.
        squared_error = F.mse_loss(ks_input, ks_gt, reduction="none")
        # Reduce dim 1, then what was originally dim 2, dropping exactly those
        # two dims. A bare ``squeeze()`` would also drop a size-1 batch or
        # width dim. (presumably dims 1 and 2 are the real/imag channel and
        # the rows, which would explain the 2 * w normaliser -- TODO confirm)
        ks_row_mse = squared_error.sum(1).sum(1) / (2 * w)
        energy = torch.exp(-ks_row_mse * self.gamma)
        target_tensor[:] = energy
    elif degree != 1:
        target_tensor[:] = degree

    # Force the observed part (non-zero mask entries) to always be labelled 1.
    # NOTE(review): applied on the fake-target path only -- confirm against
    # the original layout of this method.
    for i in range(mask.shape[0]):
        idx = torch.nonzero(mask[i, 0, 0, :])
        target_tensor[i, idx] = 1
    return target_tensor