def forward()

in timm/models/mvitv2.py (MultiScaleAttentionPoolFirst)

This is the pool-first attention variant: q, k and v start as per-head folds of the same input tokens, are spatially pooled (and optionally normalized) first, and only then linearly projected, before scaled dot-product attention with optional decomposed relative position bias and a residual connection from the pooled query.


    def forward(self, x, feat_size: List[int]):
        B, N, _ = x.shape

        # Fold the channel dim out into heads before pooling:
        # (B, N, C) -> (B, fold_dim, N, C // fold_dim). fold_dim is 1 when
        # pooling is 'unshared' (the pooling op sees all channels at once),
        # otherwise num_heads so pooling is applied per head.
        fold_dim = 1 if self.unshared else self.num_heads
        x = x.reshape(B, N, fold_dim, -1).permute(0, 2, 1, 3)
        q = k = v = x

        # q, k and v are each pooled *before* their linear projections (the
        # "pool first" ordering). reshape_pre_pool lays tokens out as a
        # conv-friendly (B * fold_dim, C', H, W) grid, splitting off any
        # class token; reshape_post_pool folds the result back to token
        # form and returns the pooled (H, W).
        if self.pool_q is not None:
            q, q_tok = reshape_pre_pool(q, feat_size, self.has_cls_token)
            q = self.pool_q(q)
            q, q_size = reshape_post_pool(q, self.num_heads, q_tok)
        else:
            q_size = feat_size
        if self.norm_q is not None:
            q = self.norm_q(q)

        if self.pool_k is not None:
            k, k_tok = reshape_pre_pool(k, feat_size, self.has_cls_token)
            k = self.pool_k(k)
            k, k_size = reshape_post_pool(k, self.num_heads, k_tok)
        else:
            k_size = feat_size
        if self.norm_k is not None:
            k = self.norm_k(k)
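
        # Shape walkthrough (hypothetical numbers, shared pooling assumed):
        # with B=1, C=96, num_heads=2, feat_size=(56, 56), a class token,
        # and stride-2 pooling on k:
        #   after fold:  k: (1, 2, 3137, 48)    # 3137 = 56*56 + 1
        #   pre-pool:    grid (2, 48, 56, 56), class token split off
        #   post-pool:   k: (1, 2, 785, 48), k_size = (28, 28)  # 785 = 28*28 + 1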

        if self.pool_v is not None:
            v, v_tok = reshape_pre_pool(v, feat_size, self.has_cls_token)
            v = self.pool_v(v)
            v, v_size = reshape_post_pool(v, self.num_heads, v_tok)
        else:
            v_size = feat_size
        if self.norm_v is not None:
            v = self.norm_v(v)

        # Unfold heads back into the channel dim, apply the linear
        # projections on the pooled tokens, then split heads again.
        q_N = q_size[0] * q_size[1] + int(self.has_cls_token)
        q = q.transpose(1, 2).reshape(B, q_N, -1)
        q = self.q(q).reshape(B, q_N, self.num_heads, -1).transpose(1, 2)

        k_N = k_size[0] * k_size[1] + int(self.has_cls_token)
        k = k.transpose(1, 2).reshape(B, k_N, -1)
        # k is laid out (B, num_heads, head_dim, k_N) so the attention
        # matmul below needs no further transpose.
        k = self.k(k).reshape(B, k_N, self.num_heads, -1).permute(0, 2, 3, 1)

        v_N = v_size[0] * v_size[1] + int(self.has_cls_token)
        v = v.transpose(1, 2).reshape(B, v_N, -1)
        v = self.v(v).reshape(B, v_N, self.num_heads, -1).transpose(1, 2)
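        # At this point: q is (B, heads, q_N, head_dim), k is
        # (B, heads, head_dim, k_N), v is (B, heads, v_N, head_dim);
        # k_N must equal v_N for the attn @ v product below.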

        attn = (q * self.scale) @ k  # (B, num_heads, q_N, k_N)
        if self.rel_pos_type == 'spatial':
            # Add MViTv2's decomposed relative positional bias along H and W.
            attn = cal_rel_pos_type(
                attn,
                q,
                self.has_cls_token,
                q_size,
                k_size,
                self.rel_pos_h,
                self.rel_pos_w,
            )
        attn = attn.softmax(dim=-1)
        x = attn @ v

        if self.residual_pooling:
            # MViTv2's residual pooling connection: add the pooled query
            # back onto the attention output.
            x = x + q

        x = x.transpose(1, 2).reshape(B, -1, self.dim_out)
        x = self.proj(x)

        return x, q_size
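
For intuition, a compact self-contained sketch of the same pool-first pattern in plain torch follows. This is a simplified illustration under stated assumptions, not the timm implementation: the function name is hypothetical, there is no class token or relative position bias, pooling is shared across heads, and q, k and v use one common stride.

    import torch
    import torch.nn as nn

    def pool_first_attn_sketch(x, num_heads=2, stride=2):
        # x: (B, H*W, C) square token grid, no class token.
        B, N, C = x.shape
        H = W = int(N ** 0.5)
        head_dim = C // num_heads
        pool = nn.MaxPool2d(stride, stride)  # stand-in for pool_q / pool_k / pool_v
        proj_q, proj_k, proj_v = (nn.Linear(C, C) for _ in range(3))

        # Fold heads out before pooling: (B, N, C) -> (B, heads, N, head_dim).
        t = x.reshape(B, N, num_heads, head_dim).permute(0, 2, 1, 3)

        def pool_tokens(t):
            # Tokens -> 2D grid -> pool -> tokens, roughly what
            # reshape_pre_pool / reshape_post_pool do around pool_*.
            g = t.reshape(B * num_heads, H, W, head_dim).permute(0, 3, 1, 2)
            g = pool(g)
            h, w = g.shape[-2:]
            return g.flatten(2).transpose(1, 2).reshape(B, num_heads, h * w, head_dim), (h, w)

        q, q_size = pool_tokens(t)
        k, _ = pool_tokens(t)
        v, _ = pool_tokens(t)

        def project(t, proj):
            # Unfold heads, apply the linear projection, split heads again:
            # the projection runs on the *pooled* tokens ("pool first").
            n = t.shape[2]
            t = t.transpose(1, 2).reshape(B, n, C)
            return proj(t).reshape(B, n, num_heads, head_dim).transpose(1, 2)

        q, k, v = project(q, proj_q), project(k, proj_k), project(v, proj_v)

        attn = ((q * head_dim ** -0.5) @ k.transpose(-2, -1)).softmax(dim=-1)
        x = attn @ v + q  # residual pooling connection
        return x.transpose(1, 2).reshape(B, -1, C), q_size

    out, q_size = pool_first_attn_sketch(torch.randn(2, 56 * 56, 96))
    print(out.shape, q_size)  # torch.Size([2, 784, 96]) (28, 28)

The full module is exercised by any MViTv2 model in timm, so the easiest way to poke at these shapes is through a model instance:

    import timm
    import torch

    model = timm.create_model('mvitv2_tiny', pretrained=False)
    out = model(torch.randn(1, 3, 224, 224))  # (1, 1000) classification logits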