def info_content_weight_map()

in src/util/IW_SSIM_PyTorch.py [0:0]


    def info_content_weight_map(self, imgopr, imgdpr):
        """Compute the IW-SSIM information-content weight map for each scale.

        For every scale index in ``range(1, self.Nsc)`` this:

        1. estimates, in local ``self.blSzX x self.blSzY`` box windows, a gain
           factor ``g`` and a residual variance ``vv`` relating the reference
           band ``imgopr[scale]`` to the distorted band ``imgdpr[scale]``;
        2. gathers every window neighbourhood of the reference band (plus,
           when ``self.parent`` is set and a coarser scale exists, the
           co-located sample of the enlarged parent band ``imgopr[scale + 1]``)
           into a matrix ``Y`` and fits its covariance ``C_u``, clipping
           negative eigenvalues while preserving the total variance;
        3. accumulates one ``log2`` mutual-information term per eigenvalue.

        Args:
            imgopr: Mapping from scale index to reference band tensors. Each
                band is read as a 4-D tensor and must broadcast to
                ``(1, 1, H, W)`` (presumably exactly that shape -- TODO
                confirm against the caller).
            imgdpr: Mapping from scale index to distorted band tensors with
                the same layout as ``imgopr``.

        Returns:
            dict: ``{scale: infow}``. For odd window sizes, ``infow`` has
            shape ``(1, 1, H - blSzX + 1, W - blSzY + 1)``, entries below
            ``1e-15`` are zeroed.

        Relies on ``self.roll(x, shift, dim)`` (presumably a circular shift)
        and ``self.imenlarge2(x)`` (an enlargement whose output covers the
        child band's ``H x W``) defined elsewhere on the class.
        """
        tol = 1e-15
        dtype_str = self.samplet.type()
        blk_h, blk_w = self.blSzX, self.blSzY

        # Normalised box filter for local means / (co)variances. It does not
        # depend on the scale, so build it once.
        win = np.ones([blk_h, blk_w])
        win = torch.from_numpy(win / np.sum(win)).unsqueeze(0).unsqueeze(0).type(dtype_str)
        # Pad each spatial axis by its own half-window so the statistics stay
        # aligned with the input even when the window is not square.
        padding = ((blk_h - 1) // 2, (blk_w - 1) // 2)
        Ly, Lx = padding

        # C_u below is symmetric (Y^T Y / nexp), so use the symmetric solver.
        # torch.eig was deprecated in PyTorch 1.9 and later removed.
        eigh = getattr(getattr(torch, "linalg", None), "eigh", None)

        iw_map = {}
        for scale in range(1, self.Nsc):
            imgo = imgopr[scale]
            imgd = imgdpr[scale]

            # --- Local statistics ------------------------------------------
            mean_x = F.conv2d(imgo, win, padding=padding)
            mean_y = F.conv2d(imgd, win, padding=padding)
            cov_xy = F.conv2d(imgo * imgd, win, padding=padding) - mean_x * mean_y
            ss_x = F.conv2d(imgo ** 2, win, padding=padding) - mean_x ** 2
            ss_y = F.conv2d(imgd ** 2, win, padding=padding) - mean_y ** 2
            # Round-off can push the variances slightly below zero.
            ss_x[ss_x < 0] = 0
            ss_y[ss_y < 0] = 0

            # --- Gain factor g and residual (error) variance vv --------------
            g = cov_xy / (ss_x + tol)
            vv = ss_y - g * cov_xy
            flat_ref = ss_x < tol  # reference window is (numerically) flat
            g[flat_ref] = 0
            vv[flat_ref] = ss_y[flat_ref]
            flat_dst = ss_y < tol  # distorted window is (numerically) flat
            g[flat_dst] = 0
            vv[flat_dst] = 0

            # --- Reference band (+ optional parent band) ---------------------
            _, _, Nsy, Nsx = imgo.shape
            prnt = bool(self.parent and scale < self.Nsc - 1)
            BL = torch.zeros([1, 1, Nsy, Nsx, 1 + int(prnt)])
            if self.use_cuda:
                BL = BL.cuda()
            if self.use_double:
                BL = BL.double()
            BL[:, :, :, :, 0] = imgo
            if prnt:
                auxp = self.imenlarge2(imgopr[scale + 1])
                BL[:, :, :, :, 1] = auxp[:, :, 0:Nsy, 0:Nsx]

            # --- Group neighbouring pixels into the rows of Y -----------------
            nblv = Nsy - blk_h + 1
            nblh = Nsx - blk_w + 1
            nexp = nblv * nblh
            N = blk_h * blk_w + int(prnt)
            Y = torch.zeros([nexp, N]).type(dtype_str)

            band = BL[0, 0, :, :, 0]
            col = 0
            for ny in range(-Ly, Ly + 1):
                rows_shifted = self.roll(band, ny, 0)
                for nx in range(-Lx, Lx + 1):
                    shifted = self.roll(rows_shifted, nx, 1)
                    Y[:, col] = shifted[Ly:Ly + nblv, Lx:Lx + nblh].flatten()
                    col += 1
            if prnt:
                Y[:, col] = BL[0, 0, Ly:Ly + nblv, Lx:Lx + nblh, 1].flatten()

            # --- Covariance model with non-negative eigenvalues ---------------
            C_u = torch.mm(Y.t(), Y) / float(nexp)
            if eigh is not None:
                eig_values, H = eigh(C_u)
            else:  # PyTorch < 1.8 has no torch.linalg.eigh
                eig_values, H = torch.symeig(C_u, eigenvectors=True)
            # Zero out negative eigenvalues, then rescale so the total variance
            # (sum of eigenvalues) is unchanged. Guard against dividing by 0.
            positive = eig_values * (eig_values > 0).to(eig_values.dtype)
            pos_sum = positive.sum()
            denom = pos_sum + (pos_sum == 0).to(pos_sum.dtype)
            L = torch.diag(positive) * eig_values.sum() / denom
            C_u = torch.mm(torch.mm(H, L), H.t())
            C_u_inv = torch.inverse(C_u)

            # Per-neighbourhood normalised energy y^T C_u^{-1} y / N.
            ss = (torch.mm(Y, C_u_inv) * Y / float(N)).sum(1)
            ss = ss.view(nblv, nblh).unsqueeze(0).unsqueeze(0)
            g = g[:, :, Ly:Ly + nblv, Lx:Lx + nblh]
            vv = vv[:, :, Ly:Ly + nblv, Lx:Lx + nblh]

            # --- Mutual information -----------------------------------------
            # NOTE(review): this uses the uncorrected eigenvalues, not the
            # clipped/rescaled diagonal of L. Confirm against the reference
            # IW-SSIM implementation.
            sigma_nsq = self.sigma_nsq
            infow = torch.zeros(g.shape).type(dtype_str)
            for lam in eig_values:
                infow = infow + torch.log2(
                    1 + ((vv + (1 + g * g) * sigma_nsq) * ss * lam + sigma_nsq * vv)
                    / (sigma_nsq * sigma_nsq)
                )
            infow[infow < tol] = 0
            iw_map[scale] = infow

        return iw_map