# models/gpvit/gpvit_adapter.py
# Copyright (c) Shanghai AI Lab. All rights reserved.
import logging
import math
import time
from functools import partial
from typing import List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmdet.models.builder import BACKBONES
from timm.models.layers import DropPath, trunc_normal_
from torch.nn.init import normal_
from mmcls.gpvit_dev.models.backbones.gpvit import GPViT, resize_pos_embed
from .adapter_modules import SpatialPriorModule, InteractionBlock, get_reference_points, MSDeformAttn

_logger = logging.getLogger(__name__)


@BACKBONES.register_module(force=True)
class GPViTAdapter(GPViT):
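    """GPViT backbone wrapped with ViT-Adapter-style interaction blocks.

    A convolutional spatial prior module (SPM) supplies a multi-scale feature
    pyramid that exchanges information with the transformer tokens through
    injector/extractor deformable-attention blocks; the result is a 4-level
    feature pyramid (strides 4/8/16/32) for detection heads.
    """
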
def __init__(self,
pretrain_size=224,
conv_inplane=64,
n_points=4,
deform_num_heads=6,
init_values=0.,
interaction_indexes=None,
with_cffn=True,
cffn_ratio=0.25,
deform_ratio=1.0,
add_vit_feature=True,
use_extra_extractor=True,
att_with_cp=False,
group_with_cp=False,
*args,
**kwargs):
self.att_with_cp = att_with_cp
self.group_with_cp = group_with_cp
super().__init__(*args, **kwargs)
self.num_classes = 80
self.cls_token = None
self.num_block = len(self.layers)
self.pretrain_size = (pretrain_size, pretrain_size)
self.interaction_indexes = interaction_indexes
self.add_vit_feature = add_vit_feature
embed_dim = self.embed_dims
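        # NOTE: assumed reconstruction -- forward() below calls self.spm and
        # self._add_level_embed, which are not defined anywhere in this file.
        # The two lines below mirror the standard ViT-Adapter construction;
        # the SpatialPriorModule keyword names are assumptions.
        self.level_embed = nn.Parameter(torch.zeros(3, embed_dim))
        self.spm = SpatialPriorModule(inplanes=conv_inplane, embed_dim=embed_dim)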
self.interactions = nn.Sequential(*[
InteractionBlock_GPViT(
dim=embed_dim,
num_heads=deform_num_heads,
n_points=n_points,
init_values=init_values,
drop_path=self.drop_path_rate,
# norm_layer=self.norm1,
with_cffn=with_cffn,
cffn_ratio=cffn_ratio,
deform_ratio=deform_ratio,
                extra_extractor=((i == len(interaction_indexes) - 1) and use_extra_extractor),
down_stride=8
)
for i in range(len(interaction_indexes))
])
self.up = nn.ConvTranspose2d(embed_dim, embed_dim, 2, 2)
self.ad_norm1 = nn.BatchNorm2d(embed_dim)
self.ad_norm2 = nn.BatchNorm2d(embed_dim)
self.ad_norm3 = nn.BatchNorm2d(embed_dim)
self.ad_norm4 = nn.BatchNorm2d(embed_dim)
self.up.apply(self._init_weights)
self.interactions.apply(self._init_weights)
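        self.spm.apply(self._init_weights)  # assumed, matching the reconstruction above
        normal_(self.level_embed)           # assumed ViT-Adapter-style init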
self.apply(self._init_deform_weights)
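
    def _add_level_embed(self, c2, c3, c4):
        # Assumed helper (called in forward() but missing from this file):
        # tag each SPM scale with a learned level embedding so the deformable
        # attention can tell the three levels apart.
        c2 = c2 + self.level_embed[0]
        c3 = c3 + self.level_embed[1]
        c4 = c4 + self.level_embed[2]
        return c2, c3, c4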

    def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm) or isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
fan_out //= m.groups
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
m.bias.data.zero_()

    def _init_deform_weights(self, m):
if isinstance(m, MSDeformAttn):
m._reset_parameters()

    def _get_pos_embed(self, pos_embed, H, W):
pos_embed = pos_embed.reshape(
1, self.pretrain_size[0] // 16, self.pretrain_size[1] // 16, -1).permute(0, 3, 1, 2)
pos_embed = F.interpolate(pos_embed, size=(H, W), mode='bicubic', align_corners=False).\
reshape(1, -1, H * W).permute(0, 2, 1)
return pos_embed

    def forward(self, x):
deform_inputs1, deform_inputs2 = deform_inputs(x)
# SPM forward
c1, c2, c3, c4 = self.spm(x) # s4, s8, s16, s32
c2, c3, c4 = self._add_level_embed(c2, c3, c4)
c = torch.cat([c2, c3, c4], dim=1)
B = x.shape[0]
x, patch_resolution = self.patch_embed(x)
H, W = patch_resolution
bs, n, dim = x.shape
pos_embed = resize_pos_embed(
self.pos_embed,
self.patch_resolution,
patch_resolution,
mode=self.interpolate_mode,
num_extra_tokens=0)
x = x + pos_embed
x = self.drop_after_pos(x)
# Interaction
for i, layer in enumerate(self.interactions):
indexes = self.interaction_indexes[i]
x, c = layer(x, c, self.layers[indexes[0]:indexes[-1] + 1],
deform_inputs1, deform_inputs2, patch_resolution)
# Split & Reshape
c2 = c[:, 0:c2.size(1), :]
c3 = c[:, c2.size(1):c2.size(1) + c3.size(1), :]
c4 = c[:, c2.size(1) + c3.size(1):, :]
c2 = c2.transpose(1, 2).view(bs, dim, H, W).contiguous()
c3 = c3.transpose(1, 2).view(bs, dim, H // 2, W // 2).contiguous()
c4 = c4.transpose(1, 2).view(bs, dim, H // 4, W // 4).contiguous()
c1 = self.up(c2) + c1
if self.add_vit_feature:
x2 = x.transpose(1, 2).view(bs, dim, H, W).contiguous()
x1 = F.interpolate(x2, scale_factor=2, mode='bilinear', align_corners=False)
x3 = F.interpolate(x2, scale_factor=0.5, mode='bilinear', align_corners=False)
x4 = F.interpolate(x2, scale_factor=0.25, mode='bilinear', align_corners=False)
c1, c2, c3, c4 = c1 + x1, c2 + x2, c3 + x3, c4 + x4
# Final Norm
f1 = self.ad_norm1(c1)
f2 = self.ad_norm2(c2)
f3 = self.ad_norm3(c3)
f4 = self.ad_norm4(c4)
return [f1, f2, f3, f4]


@BACKBONES.register_module(force=True)
class GPViTAdapterSingleStageESOD(GPViTAdapter):
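    """Single-stage ESOD variant of the adapter.

    forward() consumes precomputed pyramid features (c2, c3, c4) instead of
    running an internal SPM and, when cluster boxes are provided, restricts
    the transformer to the cropped cluster windows (the ESOD sparse path).
    """
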
def __init__(self,
arch="L2",
pretrain_size=224,
conv_inplane=64,
n_points=4,
deform_num_heads=6,
init_values=0.,
                 # groups the 12 GPViT layers into four interaction stages
                 interaction_indexes=[[0, 2], [3, 5], [6, 8], [9, 11]],
with_cffn=True,
cffn_ratio=0.25,
deform_ratio=1.0,
add_vit_feature=True,
use_extra_extractor=True,
att_with_cp=False,
group_with_cp=False,
*args,
**kwargs):
self.att_with_cp = att_with_cp
self.group_with_cp = group_with_cp
        # Single-stage ESOD settings: fixed arch defaults for the backbone.
        kwargs.update({'arch': arch, "drop_path_rate": 0.1, "out_indices": (11,),
                       "final_norm": False, "convert_syncbn": False})
        # Skip GPViTAdapter.__init__ on purpose and call GPViT.__init__ directly;
        # this subclass builds its own, smaller set of adapter modules below.
        super(GPViTAdapter, self).__init__(*args, **kwargs)
self.num_classes = 80
self.cls_token = None
self.num_block = len(self.layers)
self.pretrain_size = (pretrain_size, pretrain_size)
self.interaction_indexes = interaction_indexes
self.add_vit_feature = add_vit_feature
embed_dim = self.embed_dims
self.interactions = nn.Sequential(*[
InteractionBlock_GPViT(
dim=embed_dim,
num_heads=deform_num_heads,
n_points=n_points,
init_values=init_values,
drop_path=self.drop_path_rate,
# norm_layer=self.norm1,
with_cffn=with_cffn,
cffn_ratio=cffn_ratio,
deform_ratio=deform_ratio,
                extra_extractor=((i == len(interaction_indexes) - 1) and use_extra_extractor),
down_stride=8
)
for i in range(len(interaction_indexes))
])
self.ad_norm2 = nn.BatchNorm2d(embed_dim)
self.ad_norm3 = nn.BatchNorm2d(embed_dim)
self.ad_norm4 = nn.BatchNorm2d(embed_dim)
self.interactions.apply(self._init_weights)
self.apply(self._init_deform_weights)

    def build_mask_indices(self, clusters: Optional[torch.Tensor], feat_size: Tuple[int, int]):
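        """Flatten cluster boxes into gather indices on the stride-8 grid.

        clusters holds shape(n, 5) rows of (batch_idx, x1, y1, x2, y2) on the
        (H, W) feature grid; all boxes share a single window size (w, h).
        Returns (mask_indices, slice_ratio): the flat index of every cell
        inside each window (n * w*h entries) and the ratio H // h.
        """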
if clusters is None:
return None, None
assert clusters.size(0) > 0
H, W = feat_size
device = clusters.device
bi, x1, y1, x2, y2 = clusters.chunk(5, dim=1) # shape(n,1)
w, h = int(x2[0, 0] - x1[0, 0]), int(y2[0, 0] - y1[0, 0])
if getattr(self, 'slice_ratio', None) is None:
            self.slice_ratio = H // h
assert H // h == W // w == self.slice_ratio
if getattr(self, 'grid_off', None) is None or self.grid_off.size(1) != w*h:
gy, gx = torch.meshgrid(torch.arange(h), torch.arange(w))
gxy = torch.stack((gy.reshape(-1), gx.reshape(-1)), dim=0) # shape(2, w*h)
self.grid_off = gxy.to(device)
gy, gx = self.grid_off.chunk(2, dim=0) # shape(1,w*h)
mask_indices = (bi * H * W + (gy + y1) * W + (gx + x1)).view(-1) # shape(n * w*h)
return mask_indices, self.slice_ratio

    def feat_slice(self, featmaps: List[torch.Tensor], clusters: torch.Tensor, scales: List[int]):
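        """Crop dense feature maps down to the cluster windows.

        Gathers one (h/s, w/s) window per cluster from each map in featmaps
        (one stride s per map) and returns (n, C, h/s, w/s) tensors.
        """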
assert len(featmaps) == len(scales)
featmaps_new = []
B, C, H, W = featmaps[0].shape
device = featmaps[0].device
bi, x1, y1, x2, y2 = clusters.chunk(5, dim=1) # shape(n,1)
w, h = (x2[0, 0] - x1[0, 0]).item(), (y2[0, 0] - y1[0, 0]).item()
assert H // h == W // w == self.slice_ratio
if getattr(self, 'fs_grids', None) is None or self.fs_grids[0].size(1) != w*h:
self.fs_grids = []
for s in scales:
gy, gx = torch.meshgrid(torch.arange(h//s), torch.arange(w//s))
gxy = torch.stack((gy.reshape(-1), gx.reshape(-1)), dim=0) # shape(2, w*h)
self.fs_grids.append(gxy.to(device))
for fi, s in enumerate(scales):
t, l = y1 // s, x1 // s
gj, gi = self.fs_grids[fi].chunk(2, dim=0) # shape(1,w*h)
H_, W_ = H // s, W // s
mask_indices = (bi * H_ * W_ + (gj + t) * W_ + (gi + l)).view(-1) # shape(n * w*h)
fm = featmaps[fi].flatten(2).transpose(1, 2).contiguous() # (B,C,H,W) -> (B,H*W,C)
fm = fm.view(-1, C)[mask_indices].view(-1, h//s, w//s, C)
fm = fm.permute(0, 3, 1, 2).contiguous()
featmaps_new.append(fm)
return featmaps_new

    def feat_slice2(self, featembs: List[torch.Tensor], scales: List[int],
clusters: torch.Tensor, mask_patch_resolution: List[int]):
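        """Token-sequence version of feat_slice.

        Gathers the cluster windows from flattened (B, L, C) embeddings and
        returns (n, (h/s)*(w/s), C) sequences, one entry per cluster.
        """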
assert len(featembs) == len(scales)
h, w = mask_patch_resolution # shape of feature patch
H, W = h * self.slice_ratio, w * self.slice_ratio # shape of feature map
device = featembs[0].device
featembs_new = []
if getattr(self, 'fs_grids', None) is None or self.fs_grids[0].size(1) != h*w:
self.fs_grids = []
for s in scales:
gy, gx = torch.meshgrid(torch.arange(h//s), torch.arange(w//s))
gxy = torch.stack((gy.reshape(-1), gx.reshape(-1)), dim=0) # shape(2, w*h)
self.fs_grids.append(gxy.to(device))
bi, x1, y1, x2, y2 = clusters.chunk(5, dim=1) # shape(n,1)
for fi, (fm, s) in enumerate(zip(featembs, scales)):
B, L, C = fm.shape
fm = fm.view(-1, C)
t, l = y1 // s, x1 // s
gj, gi = self.fs_grids[fi].chunk(2, dim=0) # shape(1,w*h)
H_, W_ = H // s, W // s
mask_indices = (bi * H_ * W_ + (gj + t) * W_ + (gi + l)).view(-1) # shape(n * w*h)
fm = fm[mask_indices].contiguous().view(-1, (h//s)*(w//s), C)
featembs_new.append(fm)
return featembs_new

    def forward(self, x):
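        """x: [images, (c2, c3, c4)] or [images, (c2, c3, c4), clusters]."""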
assert isinstance(x, list), type(x)
if len(x) == 2:
x, (c2, c3, c4) = x
clusters = None
else:
            # clusters: tensor, shape(n, 5) -- rows of (batch_idx, x1, y1, x2, y2)
            # cluster boxes on the stride-8 feature grid
            x, (c2, c3, c4), clusters = x
        # Inputs for the bidirectional deformable attention (injector/extractor)
        deform_inputs1, deform_inputs2 = deform_inputs(x)
B, C, H, W = x.shape
mask_indices, cluster_size_ratio = self.build_mask_indices(clusters, (H//8, W//8))
        x, patch_resolution = self.patch_embed(x)  # stride-8 patch embedding
# x: shape(1, h/8*w/8, ndim), serve as query
assert tuple(patch_resolution) == (H // 8, W // 8)
H, W = patch_resolution
bs, n, dim = x.shape
pos_embed = resize_pos_embed(
self.pos_embed,
self.patch_resolution,
patch_resolution,
mode=self.interpolate_mode,
num_extra_tokens=0)
x = x + pos_embed
x = self.drop_after_pos(x)
        # Toggle between slicing before the interaction blocks (fast path) and
        # threading mask_indices through them, slicing only after the final norm.
        SLICE_EARLY = True
        if mask_indices is not None and SLICE_EARLY:
            # Crop tokens/features down to the cluster windows and recompute the
            # deformable-attention inputs at the window resolution (deform_inputs
            # only reads the spatial shape of the dummy tensor).
            assert cluster_size_ratio == 8
            deform_inputs1, deform_inputs2 = \
                deform_inputs(torch.zeros((0, 0, H, W), device=x.device))
            patch_resolution = (H // cluster_size_ratio, W // cluster_size_ratio)
H, W = patch_resolution
x, c2, c3, c4 = self.feat_slice2([x, c2, c3, c4], [1, 1, 2, 4], clusters, patch_resolution)
bs = x.size(0)
mask_indices, cluster_size_ratio = None, None
        # SPM features: an independent pyramid at strides 8/16/32
        # c: shape(bs, h/8*w/8 + h/16*w/16 + h/32*w/32, ndim), serves as feature
c = torch.cat([c2, c3, c4], dim=1)
# Interaction
for i, layer in enumerate(self.interactions):
indexes = self.interaction_indexes[i]
x, c = layer(x, c, self.layers[indexes[0]:indexes[-1] + 1],
deform_inputs1, deform_inputs2, patch_resolution,
mask_indices=mask_indices, cluster_size_ratio=cluster_size_ratio)
# Split & Reshape
c2 = c[:, 0:c2.size(1), :]
c3 = c[:, c2.size(1):c2.size(1) + c3.size(1), :]
c4 = c[:, c2.size(1) + c3.size(1):, :]
c2 = c2.transpose(1, 2).view(bs, dim, H, W).contiguous()
c3 = c3.transpose(1, 2).view(bs, dim, H // 2, W // 2).contiguous()
c4 = c4.transpose(1, 2).view(bs, dim, H // 4, W // 4).contiguous()
if self.add_vit_feature:
x2 = x.transpose(1, 2).view(bs, dim, H, W).contiguous()
x3 = F.interpolate(x2, scale_factor=0.5, mode='bilinear', align_corners=False)
x4 = F.interpolate(x2, scale_factor=0.25, mode='bilinear', align_corners=False)
c2, c3, c4 = c2 + x2, c3 + x3, c4 + x4
# Final Norm
f2 = self.ad_norm2(c2)
f3 = self.ad_norm3(c3)
f4 = self.ad_norm4(c4)
# torch.cuda.synchronize()
# t0 = time.time()
if mask_indices is not None:
f2, f3, f4 = self.feat_slice([f2, f3, f4], clusters, [1, 2, 4])
# torch.cuda.synchronize()
# t1 = time.time()
# print(f"Feature slicing cost {(t1-t0)*1000:.2f}ms") # 7ms
return [f2, f3, f4]

    def forward_org(self, x):
        # Original dense path, kept for reference; it needs the SPM attributes
        # built by GPViTAdapter.__init__, which this subclass skips.
        # x: shape(1, 3, 768, 1344)
        # Inputs for the bidirectional deformable attention
        deform_inputs1, deform_inputs2 = deform_inputs(x)
        # SPM forward: an independent feature pyramid at strides 8/16/32
        c2, c3, c4 = self.spm(x)  # s8, s16, s32
c2, c3, c4 = self._add_level_embed(c2, c3, c4)
# c: shape(bs, h/8*w/8 + h/16*w/16 + h/32*w/32, ndim), serve as feature
c = torch.cat([c2, c3, c4], dim=1)
B = x.shape[0]
        x, patch_resolution = self.patch_embed(x)  # stride-8 patch embedding
# x: shape(1, h/8*w/8, ndim), serve as query
H, W = patch_resolution
bs, n, dim = x.shape
pos_embed = resize_pos_embed(
self.pos_embed,
self.patch_resolution,
patch_resolution,
mode=self.interpolate_mode,
num_extra_tokens=0)
x = x + pos_embed
x = self.drop_after_pos(x)
# Interaction
for i, layer in enumerate(self.interactions):
indexes = self.interaction_indexes[i]
x, c = layer(x, c, self.layers[indexes[0]:indexes[-1] + 1],
deform_inputs1, deform_inputs2, patch_resolution)
# Split & Reshape
c2 = c[:, 0:c2.size(1), :]
c3 = c[:, c2.size(1):c2.size(1) + c3.size(1), :]
c4 = c[:, c2.size(1) + c3.size(1):, :]
c2 = c2.transpose(1, 2).view(bs, dim, H, W).contiguous()
c3 = c3.transpose(1, 2).view(bs, dim, H // 2, W // 2).contiguous()
c4 = c4.transpose(1, 2).view(bs, dim, H // 4, W // 4).contiguous()
if self.add_vit_feature:
x2 = x.transpose(1, 2).view(bs, dim, H, W).contiguous()
x3 = F.interpolate(x2, scale_factor=0.5, mode='bilinear', align_corners=False)
x4 = F.interpolate(x2, scale_factor=0.25, mode='bilinear', align_corners=False)
c2, c3, c4 = c2 + x2, c3 + x3, c4 + x4
# Final Norm
f2 = self.ad_norm2(c2)
f3 = self.ad_norm3(c3)
f4 = self.ad_norm4(c4)
return [f2, f3, f4]


class InteractionBlock_GPViT(InteractionBlock):
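    """InteractionBlock that can restrict the GPViT blocks to cluster windows.

    chunk_feat() gathers the window tokens before the transformer blocks run
    and recover_feat() scatters the results back into the dense token map, so
    the expensive attention touches only the selected regions.
    """
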
@staticmethod
def chunk_feat(x: torch.Tensor, mask_indices, patch_resolution, cluster_size_ratio=8):
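        # Gather cluster-window tokens: (B, L, C) -> (n, L/(ES*ES), C), and
        # shrink the patch resolution by the window ratio ES accordingly.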
if mask_indices is None:
return x, patch_resolution
B, L, C = x.shape
ES = cluster_size_ratio
assert L % (ES*ES) == 0 and len(mask_indices) % (L//(ES*ES)) == 0
z = x.view(-1, C)[mask_indices].view(-1, L//(ES*ES), C).contiguous()
H, W = patch_resolution
return z, (H//ES, W//ES)

    @staticmethod
def recover_feat(x: torch.Tensor, mask_indices, x0):
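        # Scatter the processed window tokens back into the dense token map x0.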
if mask_indices is None:
return x
B, L, C = x0.shape
x0.view(-1, C)[mask_indices] = x.view(-1, C).type_as(x0).contiguous()
return x0

    def forward(self, x, c, blocks, deform_inputs1, deform_inputs2, patch_resolution,
mask_indices=None, cluster_size_ratio=None):
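        # Pipeline: injector (tokens query the SPM pyramid) -> GPViT blocks
        # (optionally on cluster windows only) -> extractor(s) (the pyramid
        # queries the updated tokens).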
H, W = patch_resolution
x = x.contiguous()
COUNT_LATENCY = False
if COUNT_LATENCY:
timestamps = []
torch.cuda.synchronize()
timestamps.append(time.time())
x = self.injector(query=x,
reference_points=deform_inputs1[0],
feat=c,
spatial_shapes=deform_inputs1[1],
level_start_index=deform_inputs1[2])
if COUNT_LATENCY:
torch.cuda.synchronize()
timestamps.append(time.time())
x0 = x
x, patch_resolution = \
self.chunk_feat(x, mask_indices, patch_resolution, cluster_size_ratio)
if COUNT_LATENCY:
torch.cuda.synchronize()
timestamps.append(time.time())
if x.size(0) > 0:
for idx, blk in enumerate(blocks):
x = blk(x, patch_resolution)
if COUNT_LATENCY:
torch.cuda.synchronize()
timestamps.append(time.time())
x = self.recover_feat(x, mask_indices, x0)
if COUNT_LATENCY:
torch.cuda.synchronize()
timestamps.append(time.time())
c = self.extractor(query=c,
reference_points=deform_inputs2[0],
feat=x,
spatial_shapes=deform_inputs2[1],
level_start_index=deform_inputs2[2],
H=H, W=W)
if self.extra_extractors is not None:
for extractor in self.extra_extractors:
c = extractor(query=c,
reference_points=deform_inputs2[0],
feat=x,
spatial_shapes=deform_inputs2[1],
level_start_index=deform_inputs2[2],
H=H, W=W)
if COUNT_LATENCY:
torch.cuda.synchronize()
timestamps.append(time.time())
s = ''
for i in range(len(timestamps) - 1):
s += f'{(timestamps[i+1] - timestamps[i])*1000:.2f}ms\t'
print(s)
return x, c


def deform_inputs(x):
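    """Build (reference_points, spatial_shapes, level_start_index) triples.

    deform_inputs1 drives the injector: stride-8 queries attend over the
    three-level (8/16/32) pyramid. deform_inputs2 drives the extractor:
    three-level queries attend over the stride-8 token map.
    """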
bs, c, h, w = x.shape
spatial_shapes = torch.as_tensor([(h // 8, w // 8),
(h // 16, w // 16),
(h // 32, w // 32)],
dtype=torch.long, device=x.device)
# tensor([0, h/8*w/8, h/8*w/8 + h/16*w/16])
level_start_index = torch.cat((
spatial_shapes.new_zeros((1,)),
spatial_shapes.prod(1).cumsum(0)[:-1]
))
    # shape(1, h/8*w/8, 1, 2); (xi, yi) start at 0.5/(h/8) and step by 1.0/(h/8)
reference_points = get_reference_points([(h // 8, w // 8)], x.device)
deform_inputs1 = [reference_points, spatial_shapes, level_start_index]
spatial_shapes = torch.as_tensor([(h // 8, w // 8)], dtype=torch.long, device=x.device)
# tensor([0])
level_start_index = torch.cat((
spatial_shapes.new_zeros((1,)),
spatial_shapes.prod(1).cumsum(0)[:-1]
))
    # shape(1, h/8*w/8 + h/16*w/16 + h/32*w/32, 1, 2);
    # (xi, yi) = range(0.5/(h/s), 1.0, 1.0/(h/s)) with a per-level stride s
reference_points = get_reference_points([(h // 8, w // 8),
(h // 16, w // 16),
(h // 32, w // 32)], x.device)
deform_inputs2 = [reference_points, spatial_shapes, level_start_index]
return deform_inputs1, deform_inputs2
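

if __name__ == '__main__':
    # Minimal shape sketch (illustrative only, not part of the pipeline);
    # deform_inputs() depends only on the input's spatial shape.
    x = torch.zeros(1, 3, 768, 1344)
    d1, d2 = deform_inputs(x)
    assert d1[0].shape == (1, 96 * 168, 1, 2)  # stride-8 reference points
    assert d1[1].tolist() == [[96, 168], [48, 84], [24, 42]]
    assert d1[2].tolist() == [0, 16128, 20160]
    assert d2[1].tolist() == [[96, 168]]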