in src/util/IW_SSIM_PyTorch.py [0:0]
def info_content_weight_map(self, imgopr, imgdpr):
    """Compute the IW-SSIM information-content weight map for each scale.

    For every scale ``s`` in ``range(1, self.Nsc)`` the reference band
    ``imgopr[s]`` and the distorted band ``imgdpr[s]`` are compared with a
    normalised ``self.blSzX`` x ``self.blSzY`` box window:

    1. Local means, variances and covariance give a gain factor ``g`` and a
       residual (error) variance ``vv``.
    2. The reference band's neighbourhoods (optionally plus the co-located
       value of the enlarged parent band ``imgopr[s + 1]``) form the rows of
       ``Y``.  Their covariance ``C_u`` is eigen-decomposed.
    3. Mutual information is accumulated over the eigenvalues of ``C_u``.

    Args:
        imgopr: Indexable collection of reference pyramid bands, read at
            ``scale`` and ``scale + 1``.  Only element ``[0, 0]`` of each band
            feeds the neighbourhood statistics, so presumably each band is a
            ``(1, 1, H, W)`` tensor -- TODO confirm with callers.
        imgdpr: Distorted pyramid bands, indexed like ``imgopr``.

    Returns:
        dict mapping each scale in ``range(1, self.Nsc)`` to the weight map
        cropped to the valid region: ``(..., H - blSzX + 1, W - blSzY + 1)``.
        Entries below ``1e-15`` are set to 0.
    """
    tol = 1e-15
    dtype = self.samplet.type()
    win_v, win_h = self.blSzX, self.blSzY  # window rows x columns
    # Normalised box window.  It does not depend on the scale, so build it once.
    win = (torch.ones(1, 1, win_v, win_h, dtype=torch.float64)
           / (win_v * win_h)).type(dtype)
    Ly = (win_v - 1) // 2
    Lx = (win_h - 1) // 2
    # Per-dimension "same" padding.  A single scalar derived from blSzX alone
    # gives a mismatched output width whenever blSzX != blSzY.
    padding = (Ly, Lx)

    iw_map = {}
    for scale in range(1, self.Nsc):
        imgo = imgopr[scale]
        imgd = imgdpr[scale]

        # Local moments used to estimate the IW-SSIM parameters.
        mean_x = F.conv2d(imgo, win, padding=padding)
        mean_y = F.conv2d(imgd, win, padding=padding)
        cov_xy = F.conv2d(imgo * imgd, win, padding=padding) - mean_x * mean_y
        ss_x = F.conv2d(imgo ** 2, win, padding=padding) - mean_x ** 2
        ss_y = F.conv2d(imgd ** 2, win, padding=padding) - mean_y ** 2
        # Round-off can make E[x^2] - E[x]^2 slightly negative.
        ss_x[ss_x < 0] = 0
        ss_y[ss_y < 0] = 0

        # Gain factor g and residual variance vv, with flat-region fallbacks.
        g = cov_xy / (ss_x + tol)
        vv = ss_y - g * cov_xy
        flat_x = ss_x < tol
        g[flat_x] = 0
        vv[flat_x] = ss_y[flat_x]
        flat_y = ss_y < tol
        g[flat_y] = 0
        vv[flat_y] = 0

        # Geometry of the valid (fully covered) window positions.
        _, _, nv, nh = imgo.shape
        prnt = bool(self.parent and scale < self.Nsc - 1)
        nblv = nv - win_v + 1
        nblh = nh - win_h + 1
        nexp = nblv * nblh
        N = win_v * win_h + int(prnt)

        # Row k of Y holds the neighbourhood of valid position k: one column
        # per spatial offset, plus one column for the parent band if used.
        Y = torch.zeros(nexp, N).type(dtype)
        centre = imgo[0, 0]
        n = 0
        for ny in range(-Ly, Ly + 1):
            for nx in range(-Lx, Lx + 1):
                shifted = self.roll(self.roll(centre, ny, 0), nx, 1)
                Y[:, n] = shifted[Ly:Ly + nblv, Lx:Lx + nblh].flatten()
                n += 1
        if prnt:
            # Coarser band brought to this scale's size, cropped to the band.
            parent = self.imenlarge2(imgopr[scale + 1])[0, 0, :nv, :nh]
            Y[:, n] = parent[Ly:Ly + nblv, Lx:Lx + nblh].flatten()

        C_u = Y.t() @ Y / nexp
        # C_u is symmetric PSD.  The symmetric eigensolver returns real
        # eigenvalues and orthonormal eigenvectors (torch.eig is removed in
        # current PyTorch releases).
        eig_values, H = torch.linalg.eigh(C_u)
        # Clip negative (round-off) eigenvalues and rescale so the eigenvalue
        # sum is kept.  The `(pos_sum == 0)` term guards against 0/0.
        pos = eig_values.clamp(min=0)
        pos_sum = pos.sum()
        L = torch.diag(pos) * eig_values.sum() / (pos_sum + (pos_sum == 0))
        C_u = H @ L @ H.t()
        C_u_inv = torch.inverse(C_u)
        ss = ((Y @ C_u_inv) * Y / N).sum(dim=1)
        ss = ss.view(nblv, nblh).unsqueeze(0).unsqueeze(0)

        g = g[:, :, Ly:Ly + nblv, Lx:Lx + nblh]
        vv = vv[:, :, Ly:Ly + nblv, Lx:Lx + nblh]

        # Mutual information summed over the eigenvalues of C_u.
        # NOTE(review): this uses the unclipped eigenvalues, as the previous
        # code did; confirm whether the clipped values (diag of L) are intended.
        sigma_nsq = self.sigma_nsq
        infow = torch.zeros(g.shape).type(dtype)
        for lam in eig_values:
            infow = infow + torch.log2(
                1 + ((vv + (1 + g * g) * sigma_nsq) * ss * lam + sigma_nsq * vv)
                / (sigma_nsq * sigma_nsq))
        infow[infow < tol] = 0
        iw_map[scale] = infow
    return iw_map