def forward()

in models/gpvit/gpvit_adapter.py [0:0]


    def forward(self, x):
        """Adapter forward pass: fuse ViT backbone tokens with the SPM
        feature pyramid via bidirectional deformable-attention interactions.

        Args:
            x (list): Either ``[img, (c2, c3, c4)]`` or
                ``[img, (c2, c3, c4), clusters]``, where ``img`` is a
                ``(B, C, H, W)`` image tensor, ``c2/c3/c4`` are flattened SPM
                features (concatenated along dim 1 below), and ``clusters``
                is a boolean mask of shape ``(bs, H//8, W//8)`` — shape taken
                from the original inline comment; TODO confirm against caller.

        Returns:
            list: ``[f2, f3, f4]`` — normalized multi-scale feature maps at
            strides 8 / 16 / 32 (each ``(bs, dim, h, w)``-shaped).
        """
        assert isinstance(x, list), type(x)
        if len(x) == 2:
            x, (c2, c3, c4) = x
            clusters = None
        else:
            # clusters: bool tensor, shape (bs, h//8, w//8)
            x, (c2, c3, c4), clusters = x

        # Bidirectional deformable-attention reference points / spatial shapes.
        deform_inputs1, deform_inputs2 = deform_inputs(x)

        B, C, H, W = x.shape
        mask_indices, cluster_size_ratio = self.build_mask_indices(clusters, (H // 8, W // 8))

        x, patch_resolution = self.patch_embed(x)  # 8x downsampling
        # x: shape (1, h/8*w/8, ndim), serves as query
        assert tuple(patch_resolution) == (H // 8, W // 8)
        H, W = patch_resolution
        bs, n, dim = x.shape
        pos_embed = resize_pos_embed(
            self.pos_embed,
            self.patch_resolution,
            patch_resolution,
            mode=self.interpolate_mode,
            num_extra_tokens=0)

        x = x + pos_embed
        x = self.drop_after_pos(x)

        if mask_indices is not None:
            # Slice features down to the clustered regions and rebuild the
            # deformable-attention inputs at the reduced resolution.
            # (A dummy zero-batch tensor carries only the (H, W) shape info.)
            assert cluster_size_ratio == 8
            deform_inputs1, deform_inputs2 = \
                deform_inputs(torch.zeros((0, 0, H, W), device=x.device))
            patch_resolution = (H // cluster_size_ratio, W // cluster_size_ratio)
            H, W = patch_resolution
            x, c2, c3, c4 = self.feat_slice2([x, c2, c3, c4], [1, 1, 2, 4], clusters, patch_resolution)
            bs = x.size(0)
            mask_indices, cluster_size_ratio = None, None

        # SPM forward: an independent feature pyramid at strides 8/16/32.
        # c: shape (bs, h/8*w/8 + h/16*w/16 + h/32*w/32, ndim), serves as feature
        c = torch.cat([c2, c3, c4], dim=1)

        # Interaction: alternate runs of backbone layers with adapter fusion.
        for i, layer in enumerate(self.interactions):
            indexes = self.interaction_indexes[i]
            x, c = layer(x, c, self.layers[indexes[0]:indexes[-1] + 1],
                         deform_inputs1, deform_inputs2, patch_resolution,
                         mask_indices=mask_indices, cluster_size_ratio=cluster_size_ratio)

        # Split the concatenated pyramid back apart and reshape each level
        # from (bs, n, dim) tokens to (bs, dim, h, w) maps.
        c2 = c[:, 0:c2.size(1), :]
        c3 = c[:, c2.size(1):c2.size(1) + c3.size(1), :]
        c4 = c[:, c2.size(1) + c3.size(1):, :]

        c2 = c2.transpose(1, 2).view(bs, dim, H, W).contiguous()
        c3 = c3.transpose(1, 2).view(bs, dim, H // 2, W // 2).contiguous()
        c4 = c4.transpose(1, 2).view(bs, dim, H // 4, W // 4).contiguous()

        if self.add_vit_feature:
            # Add the (resized) ViT query features back onto each pyramid level.
            x2 = x.transpose(1, 2).view(bs, dim, H, W).contiguous()
            x3 = F.interpolate(x2, scale_factor=0.5, mode='bilinear', align_corners=False)
            x4 = F.interpolate(x2, scale_factor=0.25, mode='bilinear', align_corners=False)
            c2, c3, c4 = c2 + x2, c3 + x3, c4 + x4

        # Final per-level normalization.
        f2 = self.ad_norm2(c2)
        f3 = self.ad_norm3(c3)
        f4 = self.ad_norm4(c4)

        # NOTE(review): this branch appears unreachable — whenever
        # build_mask_indices returns a non-None mask_indices, the slicing
        # branch above runs and resets it to None before we get here.
        # Kept as-is pending confirmation of intent.
        if mask_indices is not None:
            f2, f3, f4 = self.feat_slice([f2, f3, f4], clusters, [1, 2, 4])

        return [f2, f3, f4]