# models/common.py
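
# Self-contained imports for this snippet; in the repo they sit at the top of
# models/common.py. The make_divisible import path assumes a YOLOv5-style
# utils layout.
import torch
import torch.nn.functional as F

from utils.general import make_divisible  # assumed location
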
def ada_slicer(self, mask_pred: torch.Tensor, ratio=8, threshold=0.3):
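    """Adaptively slice a batch of mask heatmaps into square cluster windows.

    Args:
        mask_pred: heatmap of shape (bs, height, width).
        ratio: each window spans roughly 1/ratio of the longer image side.
        threshold: minimum activation for a pixel to count as foreground.

    Returns:
        A list of length bs, one (n, 4) long tensor of (x1, y1, x2, y2)
        window coordinates per image.
    """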
bs, height, width = mask_pred.shape
    device = mask_pred.device
    cluster_wh = max(make_divisible(width / ratio, 4), make_divisible(height / ratio, 4))  # force square windows
    cluster_w, cluster_h = cluster_wh, cluster_wh
    # cluster_w, cluster_h = make_divisible(width / ratio, 4), make_divisible(height / ratio, 4)  # rectangular alternative
half_clus_w, half_clus_h = cluster_w // 2, cluster_h // 2
outs = []
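    # Seed detection: a pixel is an object center if it clears the threshold
    # and is the local maximum of its 3x3 neighbourhood.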
activated = mask_pred >= threshold
maxima = F.max_pool2d(mask_pred, 3, stride=1, padding=1) == mask_pred
obj_centers = activated & maxima
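    # Score each center by the mean activation of a window around it; the
    # greedy loop below visits centers in descending score order.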
padding = half_clus_w // 2
obj_sizes = F.avg_pool2d(mask_pred, padding * 2 + 1, stride=1, padding=padding)
    # batch, y and x indices of every detected center
cb, cy, cx = obj_centers.nonzero(as_tuple=True)
obj_sizes = obj_sizes[cb, cy, cx]
for bi in range(bs):
ci = cb == bi
cn = ci.sum().item()
if cn == 0:
            outs.append(torch.zeros((0, 4), device=device, dtype=torch.long))
continue
        if bs == 1:  # single image: all centers belong to it, skip the indexing
sizes = obj_sizes
cy_bi, cx_bi = cy, cx
else:
sizes = obj_sizes[ci]
cy_bi, cx_bi = cy[ci], cx[ci]
        # top-left corner of a window centred on each peak, clamped inside the image; shape (n,)
init_x1 = cx_bi.clamp(half_clus_w, width - half_clus_w) - half_clus_w
init_y1 = cy_bi.clamp(half_clus_h, height - half_clus_h) - half_clus_h
        # flat pixel-offset grid covering one window, cached across calls; shape (1, m)
        if not hasattr(self, 'grid') or self.grid is None or self.grid[0].shape[-1] != cluster_h * cluster_w:
            gy, gx = torch.meshgrid(torch.arange(cluster_h, device=device), torch.arange(cluster_w, device=device), indexing='ij')
            self.grid = (gy.reshape(1, -1), gx.reshape(1, -1))
gy, gx = self.grid
        # per-window pixel coordinates, (n, 1) + (1, m) -> (n, m), flattened for indexing
act_x, act_y = (init_x1.view(-1, 1) + gx).view(-1), (init_y1.view(-1, 1) + gy).view(-1)
act = activated[bi, act_y, act_x].view(cn, cluster_h, cluster_w)
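        # Column (x) and row (y) activation profiles of each patch: does any
        # pixel in that column/row exceed the threshold?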
act_x, act_y = act.any(dim=1).long(), act.any(dim=2).long()
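        # Inactive margins: dx1/dy1 count empty columns/rows on the left/top,
        # dx2/dy2 hold the negated count on the right/bottom.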
dx1, dx2 = (1 - act_x).argmin(dim=1), -(1 - act_x.flip((1,))).argmin(dim=1)
dy1, dy2 = (1 - act_y).argmin(dim=1), -(1 - act_y.flip((1,))).argmin(dim=1)
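        # Shift each window by the larger margin so it re-centres on the
        # activated region (positive = right/down, negative = left/up).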
dx = torch.where(dx1.abs() > dx2.abs(), dx1, dx2)
dy = torch.where(dy1.abs() > dy2.abs(), dy1, dy2)
        refine_x1 = (init_x1 + dx).clamp(0, width - cluster_w)
        refine_y1 = (init_y1 + dy).clamp(0, height - cluster_h)
        refine_x2, refine_y2 = refine_x1 + cluster_w, refine_y1 + cluster_h
        total_clusters = torch.stack((refine_x1, refine_y1, refine_x2, refine_y2), dim=1)
        # Alternative: suppress overlapping windows with NMS instead of the greedy
        # coverage loop below (torchvision expects float boxes):
        # i = torchvision.ops.nms(total_clusters.float(), sizes, 0.8)
        # clusters = total_clusters[i]
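        # overlap[i, j] is True when center j falls inside window i; shape (n, n)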
overlap = (refine_x1[:, None] <= cx_bi[None, :]) & (cx_bi[None, :] < refine_x2[:, None]) & \
(refine_y1[:, None] <= cy_bi[None, :]) & (cy_bi[None, :] < refine_y2[:, None])
clusters = []
contained = torch.full_like(overlap[0], False)
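        # Greedy cover: visit centers by descending score; accept a window only if
        # its seed center is not yet covered, then mark every center it contains.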
for max_i in torch.argsort(sizes, descending=True):
if contained[max_i]:
continue
clusters.append(total_clusters[max_i])
contained |= overlap[max_i]
        outs.append(torch.stack(clusters) if clusters else total_clusters[:0])  # fallback is defensive; clusters is non-empty when cn > 0
return outs
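
# Minimal usage sketch (illustrative; the module/variable names below are
# assumptions, not the repo's API):
#   mask_pred = torch.rand(1, 640, 640)                       # mask-branch heatmap
#   windows = model.ada_slicer(mask_pred, ratio=8, threshold=0.3)
#   x1, y1, x2, y2 = windows[0][0].tolist()                   # first window of first image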