def forward()

in mobile_cv/arch/fbnet_v2/asymmetric_correlation.py [0:0]


    def forward(self, x1, x2):
        """Calculates the correlation between x1, x2 in a neighborhood of
        (-dw_neg, dw_pos, -dh_neg, dh_pos)

        For every sampled position (_h, _w) of x1 (stride s1) and every
        displacement (_dh, _dw) with _dh in [-dh_neg, dh_pos] and _dw in
        [-dw_neg, dw_pos] (both stepped by s2), the output holds the dot
        product over channels of x1[:, _h, _w] and x2[:, _h + _dh, _w + _dw],
        passed through self.div.mul_scalar with the scalar 1 / c.

        Correlation if calculating outside of feature map is 0 (implicit zero
        padding). Ignores kernel size.

        Inputs:
            x1 (tensor): feature map1, size (n, c, h, w)
            x2 (tensor): feature map2, size (n, c, h, w)
        Returns:
            corr (tensor): correlations, size
                (n,
                 ((dw_pos + dw_neg) // s2 + 1) * ((dh_pos + dh_neg) // s2 + 1),
                 ceil(h / s1),
                 ceil(w / s1)).
                Output channels are ordered with the vertical displacement
                as the outer index and the horizontal displacement as the
                inner index. The tensor is allocated on x1.device with the
                default dtype.
        """
        assert x1.shape == x2.shape

        n, c, h, w = x1.shape
        dw_pos, dw_neg, dh_pos, dh_neg = (
            self.dw_pos,
            self.dw_neg,
            self.dh_pos,
            self.dh_neg,
        )
        s1, s2 = self.s1, self.s2
        out_c = ((dw_pos + dw_neg) // s2 + 1) * ((dh_pos + dh_neg) // s2 + 1)

        dh_offsets = range(-dh_neg, dh_pos + 1, s2)
        dw_offsets = range(-dw_neg, dw_pos + 1, s2)

        # Allocate directly on the target device (no host allocation + copy).
        # Zero-initialising makes the padding implicit: every position whose
        # displaced partner is out of bounds is simply never written.
        corr = torch.zeros(
            n, out_c, math.ceil(h / s1), math.ceil(w / s1), device=x1.device
        )

        # loop-invariant normalisation factor
        inv_c = float(1.0 / c)

        for _n in range(n):
            for _dh_idx, _dh in enumerate(dh_offsets):
                for _dw_idx, _dw in enumerate(dw_offsets):
                    # output channel index: vertical displacement is the
                    # outer index, horizontal displacement the inner one
                    _d = _dh_idx * len(dw_offsets) + _dw_idx
                    for _outh, _h in enumerate(range(0, h, s1)):
                        _h2 = _h + _dh
                        if _h2 < 0 or _h2 >= h:
                            # whole row is out of bounds -> stays 0
                            continue
                        for _outw, _w in enumerate(range(0, w, s1)):
                            _w2 = _w + _dw
                            if _w2 < 0 or _w2 >= w:
                                # out of bounds -> stays 0
                                continue
                            corr[_n, _d, _outh, _outw] = self.div.mul_scalar(
                                torch.dot(x1[_n, :, _h, _w], x2[_n, :, _h2, _w2]),
                                inv_c,
                            )
        return corr