def encode_decode()

in models/UN_EPT.py [0:0]


    def encode_decode(self, x):
        """Run the full segmentation forward pass on a batch of images.

        Args:
            x: Input batch, unpacked as ``(B, C, H, W)``.

        Returns:
            tuple: ``(seg_logits, mask_map, dir_map)``.
            ``seg_logits`` is the output of ``self.cls`` bilinearly resized
            back to the input's ``(H, W)``.
            ``mask_map`` and ``dir_map`` are the raw outputs of
            ``self.mask_head`` / ``self.dir_head`` on the spatial-branch
            features; they are not resized here.
        """
        bsize, _, h, w = x.shape
        # The context-branch output is mapped back onto an
        # (h // out_stride, w // out_stride) grid below, so this method
        # assumes it emits exactly that many tokens.
        # NOTE(review): for H/W not divisible by 8 this relies on the branches
        # also flooring their spatial size — confirm.
        out_stride = 8

        backbone_feats = self.backbone(x)
        if self.backbone_type == 'ResNetV1c':
            # Skip the first backbone output for this backbone type;
            # presumably a stage that self.layers is not built to consume.
            backbone_feats = backbone_feats[1:]

        context = self.spatial_branch(x)

        # The auxiliary heads consume the spatial features before flattening.
        mask_map = self.mask_head(context)
        dir_map = self.dir_head(context)

        # Flatten the spatial dims into a sequence-first token layout:
        # (B, C', H', W') -> (H' * W', B, C')  (assumes a 4-D context).
        context = context.flatten(2).permute(2, 0, 1)

        # Apply each projection layer to the backbone output at the same index.
        # Indexing (rather than zip) keeps the original failure mode: a
        # backbone_feats shorter than self.layers raises IndexError.
        pyramid_feats = [
            conv_layer(backbone_feats[i])
            for i, conv_layer in enumerate(self.layers)
        ]

        # Side length of the (assumed square) query grid handed to
        # context_branch; int() floors when num_queries is not a perfect square.
        q_H = q_W = int(math.sqrt(self.num_queries))
        out = self.context_branch(
            pyramid_feats, context, self.query_embed.weight, q_H, q_W)

        # Sequence-first tokens back to an image-shaped map:
        # (h/8 * w/8, B, feat_dim) -> (h/8, w/8, B, feat_dim)
        #                          -> (B, feat_dim, h/8, w/8)
        out = out.reshape(h // out_stride, w // out_stride, bsize, self.feat_dim)
        out = out.permute(2, 3, 0, 1)

        out = self.cls(out)

        # Upsample the logits to the input resolution.
        seg_logits = resize(
            input=out,
            size=x.shape[2:],
            mode='bilinear',
            align_corners=self.align_corners)

        return seg_logits, mask_map, dir_map