def get_target_tensor()

in activemri/experimental/cvpr19_models/models/fft_utils.py [0:0]


    def get_target_tensor(self, input, target_is_real, degree, mask, pred_and_gt=None):
        """Build the target tensor for the k-space GAN loss.

        Args:
            input: Discriminator output. The returned target has the same
                shape, dtype and device. Dim 0 is indexed per sample and dim 1
                by the column positions taken from ``mask``.
            target_is_real: If True, every target entry is set to ``degree``.
            degree: Fill value. For fake targets it is applied only when
                ``self.use_mse_as_energy`` is False and ``degree != 1``;
                otherwise the fake target starts from 0 (or the energy below).
            mask: Indexed as ``mask[:, 0, 0, :]``. Non-zero entries mark the
                columns that are forced to 1. It is only used for fake targets.
            pred_and_gt: ``(pred, gt)`` pair. Required only for fake targets
                when ``self.use_mse_as_energy`` is True.

        Returns:
            A tensor shaped like ``input``.
        """
        if target_is_real:
            target_tensor = torch.ones_like(input)
            target_tensor[:] = degree

        else:
            target_tensor = torch.zeros_like(input)
            if not self.use_mse_as_energy:
                if degree != 1:
                    target_tensor[:] = degree
            else:
                pred, gt = pred_and_gt
                if self.options.dataroot == "KNEE_RAW":
                    gt = center_crop(gt, [368, 320])
                    pred = center_crop(pred, [368, 320])
                # NOTE(review): `w` is gt.shape[2], i.e. the size of the dim that
                # is summed over below, not necessarily the image width. Dividing
                # by 2 * w presumably turns the sum over (channel, dim 2) into a
                # mean, assuming fft() returns 2 (real/imag) channels — confirm.
                w = gt.shape[2]
                ks_gt = fft(gt, normalized=True)
                ks_input = fft(pred, normalized=True)
                # `reduction="none"` replaces the deprecated `reduce=False`.
                ks_row_mse = (
                    F.mse_loss(ks_input, ks_gt, reduction="none")
                    .sum(1, keepdim=True)
                    .sum(2, keepdim=True)
                    .squeeze()
                    / (2 * w)
                )
                # Zero error maps to an energy of exactly 1 (exp(0)).
                energy = torch.exp(-ks_row_mse * self.gamma)

                target_tensor[:] = energy
            # Force observed columns to always be labeled real (1).
            # nonzero over the whole batch yields (sample, column) pairs, so one
            # advanced-indexing assignment replaces a per-sample loop. On CUDA,
            # torch.nonzero synchronizes with the host, so this is one sync
            # instead of one per sample.
            observed = torch.nonzero(mask[:, 0, 0, :])
            target_tensor[observed[:, 0], observed[:, 1]] = 1
        return target_tensor