# models/gpvit/adapter_modules.py

""" Reference: https://github.com/czczup/ViT-Adapter Modified: use mmcv version MultiScaleDeformableAttnFunction """ from __future__ import absolute_import, division, print_function import logging from functools import partial import math import warnings import torch import torch.nn.functional as F from torch import nn from torch.nn.init import constant_, xavier_uniform_, normal_, trunc_normal_ from mmcv.runner import force_fp32 from timm.models.layers import DropPath from mmcv.ops.multi_scale_deform_attn import MultiScaleDeformableAttnFunction _logger = logging.getLogger(__name__) def _is_power_of_2(n): if (not isinstance(n, int)) or (n < 0): raise ValueError('invalid input for _is_power_of_2: {} (type: {})'.format(n, type(n))) return (n & (n - 1) == 0) and n != 0 class MSDeformAttn(nn.Module): def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4, ratio=1.0): """Multi-Scale Deformable Attention Module. :param d_model hidden dimension :param n_levels number of feature levels :param n_heads number of attention heads :param n_points number of sampling points per attention head per feature level """ super().__init__() if d_model % n_heads != 0: raise ValueError('d_model must be divisible by n_heads, ' 'but got {} and {}'.format(d_model, n_heads)) _d_per_head = d_model // n_heads # you'd better set _d_per_head to a power of 2 # which is more efficient in our CUDA implementation if not _is_power_of_2(_d_per_head) and False: warnings.warn( "You'd better set d_model in MSDeformAttn to make " 'the dimension of each attention head a power of 2 ' 'which is more efficient in our CUDA implementation.') self.im2col_step = 64 self.d_model = d_model self.n_levels = n_levels self.n_heads = n_heads self.n_points = n_points self.ratio = ratio self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2) self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points) self.value_proj = nn.Linear(d_model, int(d_model * ratio)) self.output_proj = nn.Linear(int(d_model * ratio), d_model) self._reset_parameters() def _reset_parameters(self): constant_(self.sampling_offsets.weight.data, 0.) thetas = torch.arange( self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads) grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view( self.n_heads, 1, 1, 2).repeat(1, self.n_levels, self.n_points, 1) for i in range(self.n_points): grid_init[:, :, i, :] *= i + 1 with torch.no_grad(): self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1)) constant_(self.attention_weights.weight.data, 0.) constant_(self.attention_weights.bias.data, 0.) xavier_uniform_(self.value_proj.weight.data) constant_(self.value_proj.bias.data, 0.) xavier_uniform_(self.output_proj.weight.data) constant_(self.output_proj.bias.data, 0.) 
    @force_fp32(apply_to=('query', 'reference_points', 'input_flatten', 'input_padding_mask'))
    def forward(self, query, reference_points, input_flatten, input_spatial_shapes,
                input_level_start_index, input_padding_mask=None):
        r"""
        :param query                      (N, Length_{query}, C)
        :param reference_points           (N, Length_{query}, n_levels, 2), range in [0, 1],
                                          top-left (0, 0), bottom-right (1, 1), including padding area
                                          or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes
        :param input_flatten              (N, \sum_{l=0}^{L-1} H_l \cdot W_l, C)
        :param input_spatial_shapes       (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})]
        :param input_level_start_index    (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1,
                                          H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}]
        :param input_padding_mask         (N, \sum_{l=0}^{L-1} H_l \cdot W_l), True for padding elements,
                                          False for non-padding elements

        :return output                    (N, Length_{query}, C)
        """
        N, Len_q, _ = query.shape
        N, Len_in, _ = input_flatten.shape
        assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in

        value = self.value_proj(input_flatten)
        if input_padding_mask is not None:
            value = value.masked_fill(input_padding_mask[..., None], float(0))
        # split into heads along the channel dimension
        value = value.view(N, Len_in, self.n_heads,
                           int(self.ratio * self.d_model) // self.n_heads)
        sampling_offsets = self.sampling_offsets(query).view(
            N, Len_q, self.n_heads, self.n_levels, self.n_points, 2)
        attention_weights = self.attention_weights(query).view(
            N, Len_q, self.n_heads, self.n_levels * self.n_points)
        attention_weights = F.softmax(attention_weights, -1).\
            view(N, Len_q, self.n_heads, self.n_levels, self.n_points)

        if reference_points.shape[-1] == 2:
            offset_normalizer = torch.stack(
                [input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1)
            sampling_locations = reference_points[:, :, None, :, None, :] \
                + sampling_offsets / offset_normalizer[None, None, None, :, None, :]
        elif reference_points.shape[-1] == 4:
            sampling_locations = reference_points[:, :, None, :, None, :2] \
                + sampling_offsets / self.n_points \
                * reference_points[:, :, None, :, None, 2:] * 0.5
        else:
            raise ValueError(
                'Last dim of reference_points must be 2 or 4, but get {} instead.'
                .format(reference_points.shape[-1]))
        output = MultiScaleDeformableAttnFunction.apply(
            value.to(dtype=torch.float32), input_spatial_shapes, input_level_start_index,
            sampling_locations, attention_weights.to(dtype=torch.float32), self.im2col_step)
        output = self.output_proj(output.to(attention_weights.dtype))
        return output


def get_reference_points(spatial_shapes, device):
    reference_points_list = []
    for lvl, (H_, W_) in enumerate(spatial_shapes):
        # arange(0.5, H_, 1.) yields the same half-pixel centers as
        # linspace(0.5, H_ - 0.5, H_)
        ref_y, ref_x = torch.meshgrid(
            torch.arange(0.5, H_, 1., dtype=torch.float32, device=device),
            torch.arange(0.5, W_, 1., dtype=torch.float32, device=device))
        ref_y = ref_y.reshape(-1)[None] / H_
        ref_x = ref_x.reshape(-1)[None] / W_
        ref = torch.stack((ref_x, ref_y), -1)  # shape (1, H_*W_, 2), (x, y)
        reference_points_list.append(ref)
    reference_points = torch.cat(reference_points_list, 1)  # small (& medium & large)
    reference_points = reference_points[:, :, None]  # shape (1, L, 1, 2)
    return reference_points


def deform_inputs(x):
    bs, c, h, w = x.shape
    # inputs for the Injector: ViT tokens (1/16) attend to the 1/8, 1/16, 1/32 feature pyramid
    spatial_shapes = torch.as_tensor([(h // 8, w // 8),
                                      (h // 16, w // 16),
                                      (h // 32, w // 32)],
                                     dtype=torch.long, device=x.device)
    level_start_index = torch.cat((spatial_shapes.new_zeros((1,)),
                                   spatial_shapes.prod(1).cumsum(0)[:-1]))
    reference_points = get_reference_points([(h // 16, w // 16)], x.device)
    deform_inputs1 = [reference_points, spatial_shapes, level_start_index]

    # inputs for the Extractor: pyramid tokens attend back to the single 1/16 ViT feature map
    spatial_shapes = torch.as_tensor([(h // 16, w // 16)],
                                     dtype=torch.long, device=x.device)
    level_start_index = torch.cat((spatial_shapes.new_zeros((1,)),
                                   spatial_shapes.prod(1).cumsum(0)[:-1]))
    reference_points = get_reference_points([(h // 8, w // 8),
                                             (h // 16, w // 16),
                                             (h // 32, w // 32)], x.device)
    deform_inputs2 = [reference_points, spatial_shapes, level_start_index]

    return deform_inputs1, deform_inputs2
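# Worked example (not from the original source; input size is hypothetical):
# for a 512x512 image, deform_inputs would return, in deform_inputs1,
# spatial_shapes = [[64, 64], [32, 32], [16, 16]], level_start_index =
# [0, 4096, 5120], and reference points of shape (1, 32*32, 1, 2); in
# deform_inputs2, a single spatial shape [[32, 32]] with reference points of
# shape (1, 64*64 + 32*32 + 16*16, 1, 2) = (1, 5376, 1, 2).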
class ConvFFN(nn.Module):

    def __init__(self, in_features, hidden_features=None, out_features=None,
                 act_layer=nn.GELU, drop=0., down_stride=16):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.dwconv = DWConv(hidden_features, down_stride)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x, H, W):
        x = self.fc1(x)
        x = self.dwconv(x, H, W)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class DWConv(nn.Module):

    def __init__(self, dim=768, down_stride=16):
        super().__init__()
        self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)
        self.down_stride = down_stride

    def forward(self, x, H, W):
        if self.down_stride == 16:
            B, N, C = x.shape
            n = N // 21
            x1 = x[:, 0:16 * n, :].transpose(1, 2).view(B, C, H * 2, W * 2).contiguous()
            x2 = x[:, 16 * n:20 * n, :].transpose(1, 2).view(B, C, H, W).contiguous()
            x3 = x[:, 20 * n:, :].transpose(1, 2).view(B, C, H // 2, W // 2).contiguous()
            x1 = self.dwconv(x1).flatten(2).transpose(1, 2)
            x2 = self.dwconv(x2).flatten(2).transpose(1, 2)
            x3 = self.dwconv(x3).flatten(2).transpose(1, 2)
            x = torch.cat([x1, x2, x3], dim=1)
            return x
        elif self.down_stride == 8:
            B, N, C = x.shape
            n1 = H * W
            n2 = n1 + H * W // 4
            n3 = n2 + H * W // 16
            x1 = x[:, 0:n1, :].transpose(1, 2).view(B, C, H, W).contiguous()
            x2 = x[:, n1:n2, :].transpose(1, 2).view(B, C, H // 2, W // 2).contiguous()
            x3 = x[:, n2:, :].transpose(1, 2).view(B, C, H // 4, W // 4).contiguous()
            x1 = self.dwconv(x1).flatten(2).transpose(1, 2)
            x2 = self.dwconv(x2).flatten(2).transpose(1, 2)
            x3 = self.dwconv(x3).flatten(2).transpose(1, 2)
            x = torch.cat([x1, x2, x3], dim=1)
            return x
        else:
            raise NotImplementedError


class Extractor(nn.Module):

    def __init__(self, dim, num_heads=6, n_points=4, n_levels=1, deform_ratio=1.0,
                 with_cffn=True, cffn_ratio=0.25, drop=0., drop_path=0.,
                 norm_layer=partial(nn.LayerNorm, eps=1e-6), down_stride=16):
        super().__init__()
        self.query_norm = norm_layer(dim)
        self.feat_norm = norm_layer(dim)
        self.attn = MSDeformAttn(d_model=dim, n_levels=n_levels, n_heads=num_heads,
                                 n_points=n_points, ratio=deform_ratio)
        self.with_cffn = with_cffn
        if with_cffn:
            self.ffn = ConvFFN(in_features=dim, hidden_features=int(dim * cffn_ratio),
                               drop=drop, down_stride=down_stride)
            self.ffn_norm = norm_layer(dim)
            self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, query, reference_points, feat, spatial_shapes,
                level_start_index, H, W):
        attn = self.attn(self.query_norm(query), reference_points,
                         self.feat_norm(feat), spatial_shapes,
                         level_start_index, None)
        query = query + attn
        if self.with_cffn:
            query = query + self.drop_path(self.ffn(self.ffn_norm(query), H, W))
        return query


class Injector(nn.Module):

    def __init__(self, dim, num_heads=6, n_points=4, n_levels=1, deform_ratio=1.0,
                 norm_layer=partial(nn.LayerNorm, eps=1e-6), init_values=0.):
        super().__init__()
        self.query_norm = norm_layer(dim)
        self.feat_norm = norm_layer(dim)
        self.attn = MSDeformAttn(d_model=dim, n_levels=n_levels, n_heads=num_heads,
                                 n_points=n_points, ratio=deform_ratio)
        self.gamma = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)

    def forward(self, query, reference_points, feat, spatial_shapes, level_start_index):
        attn = self.attn(self.query_norm(query), reference_points,
                         self.feat_norm(feat), spatial_shapes,
                         level_start_index, None)
        return query + self.gamma * attn


class InteractionBlock(nn.Module):

    def __init__(self, dim, num_heads=6, n_points=4,
                 norm_layer=partial(nn.LayerNorm, eps=1e-6), drop=0., drop_path=0.,
                 with_cffn=True, cffn_ratio=0.25, init_values=0., deform_ratio=1.0,
                 extra_extractor=False, down_stride=16):
        super().__init__()
        self.injector = Injector(dim=dim, n_levels=3, num_heads=num_heads,
                                 init_values=init_values, n_points=n_points,
                                 norm_layer=norm_layer, deform_ratio=deform_ratio)
        self.extractor = Extractor(dim=dim, n_levels=1, num_heads=num_heads,
                                   n_points=n_points, norm_layer=norm_layer,
                                   deform_ratio=deform_ratio, with_cffn=with_cffn,
                                   cffn_ratio=cffn_ratio, drop=drop,
                                   drop_path=drop_path, down_stride=down_stride)
        if extra_extractor:
            self.extra_extractors = nn.Sequential(*[
                Extractor(dim=dim, num_heads=num_heads, n_points=n_points,
                          norm_layer=norm_layer, with_cffn=with_cffn,
                          cffn_ratio=cffn_ratio, deform_ratio=deform_ratio,
                          drop=drop, drop_path=drop_path, down_stride=down_stride)
                for _ in range(2)
            ])
        else:
            self.extra_extractors = None

    def forward(self, x, c, blocks, deform_inputs1, deform_inputs2, H, W):
        x = self.injector(query=x, reference_points=deform_inputs1[0],
                          feat=c, spatial_shapes=deform_inputs1[1],
                          level_start_index=deform_inputs1[2])
        for idx, blk in enumerate(blocks):
            x = blk(x, H, W)
        c = self.extractor(query=c, reference_points=deform_inputs2[0],
                           feat=x, spatial_shapes=deform_inputs2[1],
                           level_start_index=deform_inputs2[2], H=H, W=W)
        if self.extra_extractors is not None:
            for extractor in self.extra_extractors:
                c = extractor(query=c, reference_points=deform_inputs2[0],
                              feat=x, spatial_shapes=deform_inputs2[1],
                              level_start_index=deform_inputs2[2], H=H, W=W)
        return x, c


class SpatialPriorModule(nn.Module):

    def __init__(self, inplanes=64, embed_dim=384, out_c1=True):
        super().__init__()

        self.stem = nn.Sequential(*[
            nn.Conv2d(3, inplanes, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(inplanes),
            nn.ReLU(inplace=True),
            nn.Conv2d(inplanes, inplanes, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(inplanes),
            nn.ReLU(inplace=True),
            nn.Conv2d(inplanes, inplanes, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(inplanes),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        ])  # s4
        self.conv2 = nn.Sequential(*[
            nn.Conv2d(inplanes, 2 * inplanes, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(2 * inplanes),
            nn.ReLU(inplace=True)
        ])  # s8
        self.conv3 = nn.Sequential(*[
            nn.Conv2d(2 * inplanes, 4 * inplanes, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(4 * inplanes),
            nn.ReLU(inplace=True)
        ])  # s16
        self.conv4 = nn.Sequential(*[
            nn.Conv2d(4 * inplanes, 4 * inplanes, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(4 * inplanes),
            nn.ReLU(inplace=True)
        ])  # s32
        if out_c1:
            self.fc1 = nn.Conv2d(inplanes, embed_dim, kernel_size=1, stride=1, padding=0, bias=True)
        self.fc2 = nn.Conv2d(2 * inplanes, embed_dim, kernel_size=1, stride=1, padding=0, bias=True)
        self.fc3 = nn.Conv2d(4 * inplanes, embed_dim, kernel_size=1, stride=1, padding=0, bias=True)
        self.fc4 = nn.Conv2d(4 * inplanes, embed_dim, kernel_size=1, stride=1, padding=0, bias=True)
        self.out_c1 = out_c1
        self.embed_dim = embed_dim
        self.level_embed = nn.Parameter(torch.zeros(3, embed_dim))

        self.apply(self._init_weights)
        normal_(self.level_embed)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm) or isinstance(m, nn.BatchNorm2d):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x):
        c1 = self.stem(x)
        c2 = self.conv2(c1)
        c3 = self.conv3(c2)
        c4 = self.conv4(c3)
        if self.out_c1:
            c1 = self.fc1(c1)
        c2 = self.fc2(c2)
        c3 = self.fc3(c3)
        c4 = self.fc4(c4)

        bs, dim, _, _ = c1.shape
        # c1 = c1.view(bs, dim, -1).transpose(1, 2)  # 4s
        c2 = c2.view(bs, self.embed_dim, -1).transpose(1, 2)  # 8s
        c3 = c3.view(bs, self.embed_dim, -1).transpose(1, 2)  # 16s
        c4 = c4.view(bs, self.embed_dim, -1).transpose(1, 2)  # 32s

        # _add_level_embed
        c2 = c2 + self.level_embed[0]
        c3 = c3 + self.level_embed[1]
        c4 = c4 + self.level_embed[2]

        if self.out_c1:
            return c1, c2, c3, c4
        else:
            return c2, c3, c4
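

# A minimal CPU smoke test (not part of the original module): it assumes the
# default inplanes=64 / embed_dim=384 and a hypothetical 3x512x512 input, and
# only checks that the SpatialPriorModule token counts line up with the
# spatial shapes produced by deform_inputs. The deformable attention CUDA op
# itself is not exercised here.
if __name__ == '__main__':
    dummy = torch.randn(1, 3, 512, 512)

    spm = SpatialPriorModule(inplanes=64, embed_dim=384, out_c1=False).eval()
    with torch.no_grad():
        c2, c3, c4 = spm(dummy)

    deform_inputs1, deform_inputs2 = deform_inputs(dummy)

    # c2/c3/c4 are flattened 1/8, 1/16 and 1/32 feature maps, so their token
    # counts should match the per-level products of deform_inputs1's spatial shapes.
    spatial_shapes = deform_inputs1[1]
    assert c2.shape[1] == int(spatial_shapes[0].prod())
    assert c3.shape[1] == int(spatial_shapes[1].prod())
    assert c4.shape[1] == int(spatial_shapes[2].prod())
    print('c2/c3/c4:', c2.shape, c3.shape, c4.shape)
    print('level_start_index:', deform_inputs1[2].tolist())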