in slowfast/models/attention.py [0:0]
def forward(self, x, thw_shape):
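    """
    x: (B, N, C) token sequence; thw_shape: [T, H, W] of the space-time token
    grid (the CLS token, when has_cls_embed is True, is not counted in T*H*W).
    Returns the attended tokens and the pooled [T, H, W] shape of the queries.
    Relies on the module-level numpy import and the attention_pool helper
    defined earlier in this file.
    """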
    B, N, C = x.shape
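
    # pool_first: pool the raw tokens before the q/k/v linear projections,
    # so the projections later run on the shorter pooled sequences.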
    if self.pool_first:
        x = x.reshape(B, N, self.num_heads, C // self.num_heads).permute(
            0, 2, 1, 3
        )
        q = k = v = x
    else:
        q = k = v = x
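        # Pool-last path: project x to q/k/v first and split into heads,
        # giving tensors of shape (B, num_heads, N, C // num_heads).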
        q = (
            self.q(q)
            .reshape(B, N, self.num_heads, C // self.num_heads)
            .permute(0, 2, 1, 3)
        )
        k = (
            self.k(k)
            .reshape(B, N, self.num_heads, C // self.num_heads)
            .permute(0, 2, 1, 3)
        )
        v = (
            self.v(v)
            .reshape(B, N, self.num_heads, C // self.num_heads)
            .permute(0, 2, 1, 3)
        )
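
    # Pool q, k, and v independently over the (T, H, W) token grid; the CLS
    # token, if present, is handled separately inside attention_pool.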
    q, q_shape = attention_pool(
        q,
        self.pool_q,
        thw_shape,
        has_cls_embed=self.has_cls_embed,
        norm=self.norm_q if hasattr(self, "norm_q") else None,
    )
    k, k_shape = attention_pool(
        k,
        self.pool_k,
        thw_shape,
        has_cls_embed=self.has_cls_embed,
        norm=self.norm_k if hasattr(self, "norm_k") else None,
    )
    v, v_shape = attention_pool(
        v,
        self.pool_v,
        thw_shape,
        has_cls_embed=self.has_cls_embed,
        norm=self.norm_v if hasattr(self, "norm_v") else None,
    )
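
    # pool_first path: pooling has already shrunk the token count, so the
    # q/k/v projections are applied now, on the pooled sequences.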
    if self.pool_first:
        # Recompute the pooled sequence lengths; +1 accounts for the CLS token.
        q_N = (
            numpy.prod(q_shape) + 1
            if self.has_cls_embed
            else numpy.prod(q_shape)
        )
        k_N = (
            numpy.prod(k_shape) + 1
            if self.has_cls_embed
            else numpy.prod(k_shape)
        )
        v_N = (
            numpy.prod(v_shape) + 1
            if self.has_cls_embed
            else numpy.prod(v_shape)
        )
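
        # Flatten heads back to (B, N', C), apply the linear projection, then
        # re-split into heads, mirroring the pool-last path above.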
        q = q.permute(0, 2, 1, 3).reshape(B, q_N, C)
        q = (
            self.q(q)
            .reshape(B, q_N, self.num_heads, C // self.num_heads)
            .permute(0, 2, 1, 3)
        )
        v = v.permute(0, 2, 1, 3).reshape(B, v_N, C)
        v = (
            self.v(v)
            .reshape(B, v_N, self.num_heads, C // self.num_heads)
            .permute(0, 2, 1, 3)
        )
        k = k.permute(0, 2, 1, 3).reshape(B, k_N, C)
        k = (
            self.k(k)
            .reshape(B, k_N, self.num_heads, C // self.num_heads)
            .permute(0, 2, 1, 3)
        )
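
    # Scaled dot-product attention; the output length follows the pooled
    # query sequence (N = q.shape[2]).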
    attn = (q @ k.transpose(-2, -1)) * self.scale
    attn = attn.softmax(dim=-1)

    N = q.shape[2]
    x = (attn @ v).transpose(1, 2).reshape(B, N, C)
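
    # Output projection, with dropout when configured.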
    x = self.proj(x)
    if self.drop_rate > 0.0:
        x = self.proj_drop(x)
    return x, q_shape
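
# Minimal usage sketch (illustrative only, not from this file). It assumes an
# already-constructed instance of this attention module, named `attn_block`
# here; the constructor arguments are omitted from this excerpt, so take them
# from the class __init__. Shapes assume has_cls_embed=True.
#
#     import torch
#
#     B, C = 2, 96
#     thw_shape = [2, 4, 4]  # T, H, W of the space-time token grid
#     N = 1 + thw_shape[0] * thw_shape[1] * thw_shape[2]  # +1 for CLS
#     x = torch.randn(B, N, C)
#
#     out, out_thw = attn_block(x, thw_shape)  # hypothetical instance
#     # out: (B, N', C) with N' set by the pooled queries; out_thw is the
#     # pooled [T, H, W] fed to the next block as its thw_shape.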