in models/gpvit/gpvit_adapter.py [0:0]
def forward(self, x):
    """Run the adapter forward pass and return a 3-level feature pyramid.

    Args:
        x (list): Either ``[img, (c2, c3, c4)]`` or
            ``[img, (c2, c3, c4), clusters]``.  ``img`` is a ``(B, C, H, W)``
            image tensor; ``c2``/``c3``/``c4`` are SPM pyramid features at
            strides 8/16/32 (flattened to ``(bs, n, ndim)``); ``clusters``
            is a boolean mask tensor of shape ``(bs, H//8, W//8)`` used for
            feature slicing.

    Returns:
        list: ``[f2, f3, f4]`` — normalized feature maps at strides
        8/16/32 relative to the input image.
    """
    assert isinstance(x, list), type(x)
    if len(x) == 2:
        x, (c2, c3, c4) = x
        clusters = None
    else:
        # mask: tensor(bool), shape (bs, h//8, w//8)
        x, (c2, c3, c4), clusters = x

    # Bidirectional deformable-attention sampling inputs.
    deform_inputs1, deform_inputs2 = deform_inputs(x)
    H, W = x.shape[-2:]
    # NOTE(review): presumably returns (None, None) when clusters is None —
    # confirm against build_mask_indices.
    mask_indices, cluster_size_ratio = self.build_mask_indices(
        clusters, (H // 8, W // 8))

    x, patch_resolution = self.patch_embed(x)  # 8x downsampling
    # x: shape (bs, h/8 * w/8, ndim), serves as query
    assert tuple(patch_resolution) == (H // 8, W // 8)
    H, W = patch_resolution
    bs, n, dim = x.shape
    pos_embed = resize_pos_embed(
        self.pos_embed,
        self.patch_resolution,
        patch_resolution,
        mode=self.interpolate_mode,
        num_extra_tokens=0)
    x = x + pos_embed
    x = self.drop_after_pos(x)

    if mask_indices is not None:
        # Slice features early and drop the mask so the interaction layers
        # run on the reduced set of tokens.
        assert cluster_size_ratio == 8
        # Recompute deformable-attention inputs for the reduced resolution;
        # the empty tensor only carries the spatial shape and device.
        deform_inputs1, deform_inputs2 = \
            deform_inputs(torch.zeros((0, 0, H, W), device=x.device))
        patch_resolution = (H // cluster_size_ratio, W // cluster_size_ratio)
        H, W = patch_resolution
        x, c2, c3, c4 = self.feat_slice2(
            [x, c2, c3, c4], [1, 1, 2, 4], clusters, patch_resolution)
        bs = x.size(0)
        mask_indices, cluster_size_ratio = None, None

    # SPM forward: independent feature pyramid at strides 8/16/32.
    # c: shape (bs, h/8*w/8 + h/16*w/16 + h/32*w/32, ndim), serves as feature
    c = torch.cat([c2, c3, c4], dim=1)

    # Interaction: each layer runs a slice of the ViT blocks plus
    # injector/extractor cross-attention between x (query) and c (feature).
    for i, layer in enumerate(self.interactions):
        indexes = self.interaction_indexes[i]
        x, c = layer(x, c, self.layers[indexes[0]:indexes[-1] + 1],
                     deform_inputs1, deform_inputs2, patch_resolution,
                     mask_indices=mask_indices,
                     cluster_size_ratio=cluster_size_ratio)

    # Split the concatenated feature back into pyramid levels and reshape
    # each to (bs, dim, h, w).
    c2 = c[:, 0:c2.size(1), :]
    c3 = c[:, c2.size(1):c2.size(1) + c3.size(1), :]
    c4 = c[:, c2.size(1) + c3.size(1):, :]
    c2 = c2.transpose(1, 2).view(bs, dim, H, W).contiguous()
    c3 = c3.transpose(1, 2).view(bs, dim, H // 2, W // 2).contiguous()
    c4 = c4.transpose(1, 2).view(bs, dim, H // 4, W // 4).contiguous()

    if self.add_vit_feature:
        # Add the ViT query feature (resized per level) onto each pyramid level.
        x2 = x.transpose(1, 2).view(bs, dim, H, W).contiguous()
        x3 = F.interpolate(x2, scale_factor=0.5, mode='bilinear',
                           align_corners=False)
        x4 = F.interpolate(x2, scale_factor=0.25, mode='bilinear',
                           align_corners=False)
        c2, c3, c4 = c2 + x2, c3 + x3, c4 + x4

    # Final per-level norm.
    f2 = self.ad_norm2(c2)
    f3 = self.ad_norm3(c3)
    f4 = self.ad_norm4(c4)

    # Late-slicing fallback.  Unreachable while the early-slicing branch
    # above clears mask_indices; kept as the alternative slicing strategy.
    if mask_indices is not None:
        f2, f3, f4 = self.feat_slice([f2, f3, f4], clusters, [1, 2, 4])
    return [f2, f3, f4]