in segmentation/model/psanet.py [0:0]
def forward(self, x):
    """Apply point-wise spatial attention and concatenate the result to ``x``.

    Steps:

    1. Reduce the input with ``self.reduce`` (plus ``self.reduce_p`` when
       ``psa_type == 2``).
    2. Optionally downsample by ``self.shrink_factor``.
    3. Build a per-sample ``(h * w) x (h * w)`` attention map with
       ``self.attention`` / ``self.attention_p``.
    4. Aggregate the reduced features with that map.
    5. Project with ``self.proj`` and resize back to the spatial size of ``x``.
    6. Concatenate with the untouched input along dim 1.

    ``psa_type`` selects the branch. 0 and 1 use a single attention path
    (presumably PSANet's "collect" / "distribute" modes — confirm against the
    class docs). 2 runs both paths and concatenates them before ``self.proj``.

    Args:
        x: feature map that is unpacked as ``n, c, h, w`` after reduction,
            i.e. presumably a 4-D ``(N, C, H, W)`` tensor.

    Returns:
        ``torch.cat((x, attended), 1)``, where ``attended`` has the same
        spatial size as ``x``.

    Raises:
        ValueError: if ``self.psa_type`` is not 0, 1 or 2.
    """
    out = x
    scale = 1.0 / self.normalization_factor

    def _resize(feat, size):
        # Bilinear resize with align_corners=True, used for both shrink and restore.
        return F.interpolate(feat, size=size, mode='bilinear', align_corners=True)

    def _shrunk_size(h, w):
        # Spatial size after shrinking by self.shrink_factor.
        return (h - 1) // self.shrink_factor + 1, (w - 1) // self.shrink_factor + 1

    def _aggregate(feat, attn):
        # out[:, :, j] = sum_i feat[:, :, i] * attn[:, i, j] over all h * w positions,
        # then multiplied by 1 / normalization_factor.
        n, c, h, w = feat.size()
        return torch.bmm(feat.view(n, c, h * w), attn.view(n, h * w, h * w)).view(n, c, h, w) * scale

    if self.psa_type in [0, 1]:
        x = self.reduce(x)
        n, _, h, w = x.size()
        if self.shrink_factor != 1:
            h, w = _shrunk_size(h, w)
            x = _resize(x, (h, w))
        y = self.attention(x)
        if self.compact:
            if self.psa_type == 1:
                # Compact layout: swap the two h*w axes of the attention map before aggregation.
                y = y.view(n, h * w, h * w).transpose(1, 2).view(n, h * w, h, w)
        else:
            # NOTE(review): PF.psa_mask is external; assumed to return (n, h*w, h, w) — confirm.
            y = PF.psa_mask(y, self.psa_type, self.mask_h, self.mask_w)
        if self.psa_softmax:
            y = F.softmax(y, dim=1)
        x = _aggregate(x, y)
    elif self.psa_type == 2:
        x_col = self.reduce(x)
        x_dis = self.reduce_p(x)
        n, _, h, w = x_col.size()
        if self.shrink_factor != 1:
            h, w = _shrunk_size(h, w)
            x_col = _resize(x_col, (h, w))
            x_dis = _resize(x_dis, (h, w))
        y_col = self.attention(x_col)
        y_dis = self.attention_p(x_dis)
        if self.compact:
            # Only the second path is transposed in compact mode (mirrors psa_type == 1).
            y_dis = y_dis.view(n, h * w, h * w).transpose(1, 2).view(n, h * w, h, w)
        else:
            y_col = PF.psa_mask(y_col, 0, self.mask_h, self.mask_w)
            y_dis = PF.psa_mask(y_dis, 1, self.mask_h, self.mask_w)
        if self.psa_softmax:
            y_col = F.softmax(y_col, dim=1)
            y_dis = F.softmax(y_dis, dim=1)
        x = torch.cat([_aggregate(x_col, y_col), _aggregate(x_dis, y_dis)], 1)
    else:
        raise ValueError('unsupported psa_type: {!r} (expected 0, 1 or 2)'.format(self.psa_type))

    x = self.proj(x)
    if self.shrink_factor != 1:
        # Restore the input's exact spatial size. Inverting the shrink formula
        # ((h - 1) * s + 1) only recovers H when (H - 1) % s == 0; otherwise the
        # concatenation below would fail on mismatched sizes.
        x = _resize(x, tuple(out.shape[-2:]))
    return torch.cat((out, x), 1)