in mobile_cv/arch/fbnet_v2/asymmetric_correlation.py [0:0]
def forward(self, x1, x2):
    """Calculates the correlation between x1, x2 in a neighborhood of
    (-dw_neg, dw_pos, -dh_neg, dh_pos).
    Correlation when calculating outside of the feature map is 0 (implicit
    zero padding). Ignores kernel size.

    For every output position (i, j) -- every s1-th row/column of x1 -- and
    every displacement (dh, dw) -- every s2-th offset in
    [-dh_neg, dh_pos] x [-dw_neg, dw_pos] -- the output is
        dot(x1[:, :, i, j], x2[:, :, i + dh, j + dw]) / c
    The output channel index enumerates displacements with dh as the outer
    loop and dw as the inner loop.

    Inputs:
        x1 (tensor): feature map1, size (n, c, h, w)
        x2 (tensor): feature map2, size (n, c, h, w)
    Returns:
        corr (tensor): correlations on x1's device with the dtype of x1 * x2,
            size (n,
                  ((dw_pos + dw_neg) // s2 + 1) * ((dh_pos + dh_neg) // s2 + 1),
                  ceil(h / s1), ceil(w / s1))
    """
    assert x1.shape == x2.shape
    _, c, h, w = x1.shape
    dw_pos, dw_neg = self.dw_pos, self.dw_neg
    dh_pos, dh_neg = self.dh_pos, self.dh_neg
    s1, s2 = self.s1, self.s2

    # Pad x2 with zeros once instead of bounds-checking every element: any
    # displacement that lands outside the (h, w) map reads 0 from the padding
    # and therefore contributes 0 to the channel sum (for finite x1 values).
    # F.pad takes (left, right, top, bottom) for the last two dims.
    x2_pad = torch.nn.functional.pad(x2, (dw_neg, dw_pos, dh_neg, dh_pos))
    # Output positions are every s1-th row/column of x1.
    x1_sub = x1[:, :, ::s1, ::s1]

    corrs = []
    for dh in range(-dh_neg, dh_pos + 1, s2):
        top = dh + dh_neg  # row offset of displacement dh inside x2_pad
        for dw in range(-dw_neg, dw_pos + 1, s2):
            left = dw + dw_neg  # column offset of displacement dw inside x2_pad
            # x2_pad[top + i, left + j] is x2[i + dh, j + dw] (or padding),
            # sampled with the same stride s1 as x1_sub so shapes line up.
            x2_shift = x2_pad[:, :, top : top + h : s1, left : left + w : s1]
            # Channel-wise dot product for all output positions at once.
            corrs.append((x1_sub * x2_shift).sum(dim=1))
    corr = torch.stack(corrs, dim=1)
    return self.div.mul_scalar(corr, float(1.0 / c))