def ada_slicer()

in models/common.py [0:0]


    def ada_slicer(self, mask_pred: torch.Tensor, ratio=8, threshold=0.3):
        """Propose square crop windows around activated peaks of a mask.

        Steps, per call:
          1. Candidate centres are pixels that are both ``>= threshold`` and
             equal to the maximum of their 3x3 neighbourhood.
          2. Each candidate is scored with the average-pooled mask value
             around it (window side ``2 * (cluster_w // 4) + 1``).
          3. A ``cluster_h x cluster_w`` window is centred on every candidate
             (clamped into the image) and then shifted along x and y by the
             larger of its two empty margins, so that it covers more of the
             activated area.
          4. Windows are picked greedily by descending score; a candidate is
             skipped when its centre already lies inside a picked window.

        Args:
            mask_pred: Tensor that unpacks as ``(bs, height, width)``.
            ratio: The window side is the larger of
                ``make_divisible(width / ratio, 4)`` and
                ``make_divisible(height / ratio, 4)``. ``make_divisible`` is
                defined elsewhere; presumably it rounds to a multiple of 4 -
                TODO confirm.
            threshold: Minimum mask value for a pixel to count as activated.

        Returns:
            A list with ``bs`` entries. Entry ``i`` is an int64 tensor of
            shape ``(k_i, 4)`` holding ``(x1, y1, x2, y2)`` windows in mask
            coordinates, with ``x2 = x1 + cluster_w`` and
            ``y2 = y1 + cluster_h``. It is ``(0, 4)`` when image ``i`` has no
            candidate centre.

        Side effects:
            Caches the window offset grid on ``self.grid``.
        """
        bs, height, width = mask_pred.shape
        device = mask_pred.device

        # Use the same side length on both axes so every window is square.
        cluster_wh = max(make_divisible(width / ratio, 4), make_divisible(height / ratio, 4))
        cluster_w, cluster_h = cluster_wh, cluster_wh
        half_clus_w, half_clus_h = cluster_w // 2, cluster_h // 2

        # Candidate centres: activated pixels that are 3x3 local maxima.
        activated = mask_pred >= threshold
        maxima = F.max_pool2d(mask_pred, 3, stride=1, padding=1) == mask_pred
        obj_centers = activated & maxima

        # Score of a candidate: mean mask value in a box around it.
        padding = half_clus_w // 2
        obj_sizes = F.avg_pool2d(mask_pred, padding * 2 + 1, stride=1, padding=padding)

        # Batch index, row and column of every candidate centre.
        cb, cy, cx = obj_centers.nonzero(as_tuple=True)
        obj_sizes = obj_sizes[cb, cy, cx]

        # Flattened (row, col) offsets of all pixels inside one window, each of
        # shape (1, cluster_h * cluster_w). The cache is rebuilt when the
        # window size or the device changes; reusing a grid from another
        # device would break the indexing below.
        grid = getattr(self, 'grid', None)
        if grid is None or grid[0].shape[-1] != cluster_h * cluster_w or grid[0].device != device:
            # meshgrid's default ordering is row-major ('ij'): first output
            # varies along rows (y), second along columns (x).
            gy, gx = torch.meshgrid(torch.arange(cluster_h, device=device),
                                    torch.arange(cluster_w, device=device))
            grid = (gy.reshape(1, -1), gx.reshape(1, -1))
            self.grid = grid
        gy, gx = grid

        outs = []
        for bi in range(bs):
            ci = cb == bi
            cn = ci.sum().item()
            if cn == 0:
                # Same dtype as the non-empty results so callers can treat
                # every entry uniformly.
                outs.append(torch.zeros((0, 4), dtype=torch.long, device=device))
                continue

            if bs == 1:
                # Every candidate belongs to the only image; skip the masking.
                sizes = obj_sizes
                cy_bi, cx_bi = cy, cx
            else:
                sizes = obj_sizes[ci]
                cy_bi, cx_bi = cy[ci], cx[ci]

            # Top-left corner of a window centred on each candidate, with the
            # centre clamped so the window stays inside the mask. Shape (n,).
            init_x1 = cx_bi.clamp(half_clus_w, width - half_clus_w) - half_clus_w
            init_y1 = cy_bi.clamp(half_clus_h, height - half_clus_h) - half_clus_h

            # Absolute coordinates of every pixel of every window, (n, m)
            # flattened, then gathered into per-window activation patches.
            act_x = (init_x1.view(-1, 1) + gx).view(-1)
            act_y = (init_y1.view(-1, 1) + gy).view(-1)
            act = activated[bi, act_y, act_x].view(cn, cluster_h, cluster_w)

            # 0/1 profiles: which columns / rows of each window hold any
            # activated pixel.
            act_x, act_y = act.any(dim=1).long(), act.any(dim=2).long()

            # argmin over (1 - profile) lands on an active column/row; with
            # first-occurrence tie-breaking this is the count of empty
            # columns/rows at the start of the window (NOTE(review): confirm
            # tie-breaking for the torch version in use). The flipped profile
            # gives the empty count at the end, negated so it shifts the
            # window the other way.
            dx1, dx2 = (1 - act_x).argmin(dim=1), -(1 - act_x.flip((1,))).argmin(dim=1)
            dy1, dy2 = (1 - act_y).argmin(dim=1), -(1 - act_y.flip((1,))).argmin(dim=1)
            # Shift by whichever empty margin is larger.
            dx = torch.where(dx1.abs() > dx2.abs(), dx1, dx2)
            dy = torch.where(dy1.abs() > dy2.abs(), dy1, dy2)

            # Keep the coordinates as integers: casting them to the mask dtype
            # (possibly fp16/bf16) could round large values.
            refine_x1 = (init_x1 + dx).clamp(0, width - cluster_w)
            refine_y1 = (init_y1 + dy).clamp(0, height - cluster_h)
            refine_x2, refine_y2 = refine_x1 + cluster_w, refine_y1 + cluster_h
            total_clusters = torch.stack((refine_x1, refine_y1, refine_x2, refine_y2), dim=1)

            # overlap[i, j] is True when centre j lies inside window i.
            overlap = (refine_x1[:, None] <= cx_bi[None, :]) & (cx_bi[None, :] < refine_x2[:, None]) & \
                      (refine_y1[:, None] <= cy_bi[None, :]) & (cy_bi[None, :] < refine_y2[:, None])

            # Greedy selection by descending score: keep a window unless its
            # centre is already covered by a previously kept window.
            clusters = []
            contained = torch.full_like(overlap[0], False)
            for max_i in torch.argsort(sizes, descending=True):
                if contained[max_i]:
                    continue
                clusters.append(total_clusters[max_i])
                contained |= overlap[max_i]

            # cn > 0 here, and the best-scored candidate is always kept, so
            # ``clusters`` is never empty.
            outs.append(torch.stack(clusters))

        return outs