def forward()

in segmentation/model/psanet.py [0:0]


    def forward(self, x):
        """Apply point-wise spatial attention (PSA) and concatenate the result onto ``x``.

        ``psa_type`` 0 and 1 run a single branch (``self.reduce`` ->
        ``self.attention``). ``psa_type`` 2 runs two branches
        (``reduce``/``attention`` and ``reduce_p``/``attention_p``); their
        outputs are concatenated along dim 1 before ``self.proj``.

        Args:
            x: Input feature map. The output of ``self.reduce`` is unpacked as
                ``(n, c, h, w)``, so it must be 4-D.

        Returns:
            ``torch.cat((x, feat), 1)``: the untouched input followed, along
            dim 1, by the projected attention features. With
            ``shrink_factor != 1``, those features are resized back to the exact
            spatial size produced by ``self.reduce`` before shrinking. That size
            is assumed to equal ``x``'s spatial size -- TODO confirm that
            ``reduce`` keeps h/w.

        Raises:
            ValueError: if ``self.psa_type`` is not 0, 1 or 2.
        """
        if self.psa_type not in (0, 1, 2):
            # An unknown type used to skip both branches and fail later with an
            # opaque error (unbound h/w, or a channel mismatch inside proj).
            raise ValueError(f'unsupported psa_type: {self.psa_type!r}')

        out = x
        scale = 1.0 / self.normalization_factor

        def resize(feat, size):
            # Bilinear resize shared by the shrink and restore steps.
            return F.interpolate(feat, size=size, mode='bilinear', align_corners=True)

        def transpose_map(attn, n, h, w):
            # Swap the two h*w axes of a compact (n, h*w, h, w) attention map.
            hw = h * w
            return attn.view(n, hw, hw).transpose(1, 2).view(n, hw, h, w)

        def aggregate(feat, attn):
            # Each output position is a weighted sum over all h*w input positions:
            # (n, c, h*w) @ (n, h*w, h*w), then scaled by 1 / normalization_factor.
            n, c, h, w = feat.size()
            hw = h * w
            return torch.bmm(feat.view(n, c, hw), attn.view(n, hw, hw)).view(n, c, h, w) * scale

        if self.psa_type in (0, 1):
            x = self.reduce(x)
            n, _, h, w = x.size()
            full_size = (h, w)  # remembered so the restore step can invert the shrink exactly
            if self.shrink_factor != 1:
                h = (h - 1) // self.shrink_factor + 1
                w = (w - 1) // self.shrink_factor + 1
                x = resize(x, (h, w))
            y = self.attention(x)
            if self.compact:
                if self.psa_type == 1:
                    y = transpose_map(y, n, h, w)
            else:
                # NOTE(review): PF.psa_mask is external. Its result must hold
                # n * (h*w)**2 elements for aggregate(); presumably it maps the
                # over-complete (mask_h x mask_w) attention onto the h x w grid.
                # Confirm in PF.
                y = PF.psa_mask(y, self.psa_type, self.mask_h, self.mask_w)
            if self.psa_softmax:
                y = F.softmax(y, dim=1)
            x = aggregate(x, y)
        else:
            # psa_type == 2: x_col / x_dis branches (presumably "collect" and
            # "distribute"). Both are shrunk to the size derived from x_col.
            x_col = self.reduce(x)
            x_dis = self.reduce_p(x)
            n, _, h, w = x_col.size()
            full_size = (h, w)
            if self.shrink_factor != 1:
                h = (h - 1) // self.shrink_factor + 1
                w = (w - 1) // self.shrink_factor + 1
                x_col = resize(x_col, (h, w))
                x_dis = resize(x_dis, (h, w))
            y_col = self.attention(x_col)
            y_dis = self.attention_p(x_dis)
            if self.compact:
                # In compact mode only the second branch's map is transposed.
                y_dis = transpose_map(y_dis, n, h, w)
            else:
                y_col = PF.psa_mask(y_col, 0, self.mask_h, self.mask_w)
                y_dis = PF.psa_mask(y_dis, 1, self.mask_h, self.mask_w)
            if self.psa_softmax:
                y_col = F.softmax(y_col, dim=1)
                y_dis = F.softmax(y_dis, dim=1)
            x = torch.cat([aggregate(x_col, y_col), aggregate(x_dis, y_dis)], 1)
        x = self.proj(x)
        if self.shrink_factor != 1:
            # Restore to the recorded pre-shrink size. The old formula
            # (h - 1) * shrink_factor + 1 only equals it when (H - 1) is
            # divisible by shrink_factor. Otherwise the map came back too small
            # and the torch.cat below raised a size mismatch.
            x = resize(x, full_size)
        return torch.cat((out, x), 1)