in timm/models/mvitv2.py [0:0]
def forward(self, x, feat_size: List[int]):
    """Run pool-first multi-head attention over a token sequence.

    The input is split into ``fold_dim`` groups, each of q/k/v is optionally
    pooled and normalised, and only then passed through its linear
    projection (``self.q`` / ``self.k`` / ``self.v``) before scaled
    dot-product attention is applied.

    Args:
        x: 3-D tensor unpacked as ``(B, N, C)``.
        feat_size: Two-element spatial size of the incoming tokens
            (presumably ``[height, width]`` -- confirm against caller).
            ``feat_size[0] * feat_size[1]`` plus one when
            ``self.has_cls_token`` is set is used as the token count of any
            un-pooled branch.

    Returns:
        A tuple ``(out, q_size)``: ``out`` is the projected attention output
        reshaped to ``(B, -1, self.dim_out)`` and ``q_size`` is the spatial
        size of the query after optional pooling (``feat_size`` itself when
        ``self.pool_q`` is ``None``).
    """
    B, N, _ = x.shape

    # With unshared pooling all channels stay in one group; otherwise the
    # channels are folded into one group per head before pooling.
    fold_dim = 1 if self.unshared else self.num_heads
    x = x.reshape(B, N, fold_dim, -1).permute(0, 2, 1, 3)
    q = k = v = x

    # --- optional pooling + norm, applied BEFORE the linear projections ---
    if self.pool_q is not None:
        q, q_tok = reshape_pre_pool(q, feat_size, self.has_cls_token)
        q = self.pool_q(q)
        q, q_size = reshape_post_pool(q, self.num_heads, q_tok)
    else:
        q_size = feat_size
    if self.norm_q is not None:
        q = self.norm_q(q)

    if self.pool_k is not None:
        k, k_tok = reshape_pre_pool(k, feat_size, self.has_cls_token)
        k = self.pool_k(k)
        k, k_size = reshape_post_pool(k, self.num_heads, k_tok)
    else:
        k_size = feat_size
    if self.norm_k is not None:
        k = self.norm_k(k)

    if self.pool_v is not None:
        v, v_tok = reshape_pre_pool(v, feat_size, self.has_cls_token)
        v = self.pool_v(v)
        v, v_size = reshape_post_pool(v, self.num_heads, v_tok)
    else:
        v_size = feat_size
    if self.norm_v is not None:
        v = self.norm_v(v)

    # --- linear projections, then split into heads ---
    # Each branch is flattened back to (B, tokens, channels), projected, and
    # reshaped to (B, num_heads, tokens, head_dim).
    q_N = q_size[0] * q_size[1] + int(self.has_cls_token)
    q = q.transpose(1, 2).reshape(B, q_N, -1)
    q = self.q(q).reshape(B, q_N, self.num_heads, -1).transpose(1, 2)

    k_N = k_size[0] * k_size[1] + int(self.has_cls_token)
    k = k.transpose(1, 2).reshape(B, k_N, -1)
    # k must share the (B, num_heads, tokens, head_dim) layout of q and v;
    # leaving the token axis ahead of the head axis makes the score matmul
    # below contract over the wrong dimensions.
    k = self.k(k).reshape(B, k_N, self.num_heads, -1).transpose(1, 2)

    v_N = v_size[0] * v_size[1] + int(self.has_cls_token)
    v = v.transpose(1, 2).reshape(B, v_N, -1)
    v = self.v(v).reshape(B, v_N, self.num_heads, -1).transpose(1, 2)

    # Scaled dot-product scores: (B, heads, q_N, head_dim) @
    # (B, heads, head_dim, k_N) -> (B, heads, q_N, k_N).
    attn = (q * self.scale) @ k.transpose(-2, -1)
    if self.rel_pos_type == 'spatial':
        attn = cal_rel_pos_type(
            attn,
            q,
            self.has_cls_token,
            q_size,
            k_size,
            self.rel_pos_h,
            self.rel_pos_w,
        )
    attn = attn.softmax(dim=-1)
    x = attn @ v

    # Residual connection from the (pooled, projected) query.
    if self.residual_pooling:
        x = x + q

    # Merge heads back into the channel dimension and apply the output
    # projection.
    x = x.transpose(1, 2).reshape(B, -1, self.dim_out)
    x = self.proj(x)
    return x, q_size