in models/UN_EPT.py [0:0]
def encode_decode(self, x):
    """Run the full forward pass and return logits plus auxiliary maps.

    The input is fed to the backbone and to the spatial branch. The
    spatial-branch output drives the mask and direction heads and, once
    flattened, is combined with the projected backbone features inside
    the context branch. The result is classified and bilinearly resized
    to the spatial size of ``x``.

    Args:
        x: Input tensor; its shape is unpacked into four dimensions,
            the last two being used as height and width.

    Returns:
        A tuple ``(seg_logits, mask_map, dir_map)`` where ``seg_logits``
        has been resized to ``x.shape[2:]``.
    """
    batch, _, height, width = x.shape

    feats = self.backbone(x)
    if self.backbone_type == 'ResNetV1c':
        # This backbone variant yields one extra leading feature map
        # that the projection layers do not consume.
        feats = feats[1:]

    spatial = self.spatial_branch(x)
    mask_map = self.mask_head(spatial)
    dir_map = self.dir_head(spatial)

    # Collapse everything from dim 2 onward and move that axis first.
    # NOTE(review): presumably (B, C, H', W') -> (H'*W', B, C); confirm
    # against spatial_branch.
    context_seq = spatial.flatten(2).permute(2, 0, 1)

    # Indexing (rather than zip) keeps the IndexError if the backbone
    # returns fewer feature maps than there are projection layers.
    pyramid_feats = [
        proj(feats[level]) for level, proj in enumerate(self.layers)
    ]

    # The query grid is taken to be square with side floor(sqrt(num_queries)).
    query_side = int(math.sqrt(self.num_queries))
    fused = self.context_branch(
        pyramid_feats,
        context_seq,
        self.query_embed.weight,
        query_side,
        query_side,
    )

    # Lay the fused tokens out on a 1/8-resolution grid, then reorder to
    # (batch, feat_dim, h // 8, w // 8) for the classifier. The reshape
    # requires the context branch to emit exactly that many elements.
    grid = fused.reshape(height // 8, width // 8, batch, self.feat_dim)
    scores = self.cls(grid.permute(2, 3, 0, 1))

    seg_logits = resize(
        input=scores,
        size=x.shape[2:],
        mode='bilinear',
        align_corners=self.align_corners,
    )
    return seg_logits, mask_map, dir_map