in slowfast/models/vit_helper.py [0:0]
def forward(self, x, seq_len=196, num_frames=8, approx='none', num_landmarks=128):
B, N, C = x.shape
P = seq_len
F = num_frames
h = self.num_heads
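# B: batch size, N: number of tokens (1 CLS + F*P patch tokens),
# C: channel dim (= num_heads * head_dim), P: patches per frame, F: frames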
# project x to q, k, v values
q, k, v = self.qkv(x).chunk(3, dim=-1)
# Reshape: 'b n (h d) -> (b h) n d'
q, k, v = map(
lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
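# q, k, v: (B*h, N, head_dim)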
# remove CLS token from q, k, v
(cls_q, q_), (cls_k, k_), (cls_v, v_) = map(
lambda t: (t[:, 0:1], t[:, 1:]), (q, k, v))
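# cls_q/k/v: (B*h, 1, d) CLS token; q_/k_/v_: (B*h, F*P, d) patch tokens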
# let CLS token attend to key / values of all patches across time and space
cls_out = qkv_attn(cls_q * self.scale, k, v)
cls_out = rearrange(cls_out, '(b h) f d -> b f (h d)', f=1, h=h)
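# cls_out: (B, 1, h*d)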
if approx == "nystrom":
## Shared spatial landmarks
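# Nystrom approximation of the spatial attention using num_landmarks
# shared spatial landmark points (see nystrom_helper)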
q_, k_, v_ = map(
lambda t: rearrange(t, 'b h p d -> (b h) p d', h=h), (q_, k_, v_))
x = nystrom_helper.nystrom_spatial_attn(
q_, k_, v_,
landmarks=num_landmarks,
num_frames=F,
inv_iters=6,
use_spatial_landmarks=True
)
x = rearrange(x, '(b h) p f d -> b h p f d', f=F, h=h)
elif approx == "orthoformer":
x = orthoformer_helper.orthoformer(
q_, k_, v_,
num_landmarks=num_landmarks,
num_frames=F,
)
elif approx == "performer":
# Form random projection matrices:
m = 256 # r = 2m, m <= d
d = self.head_dim
seed = torch.ceil(torch.abs(torch.sum(q_) * performer_helper.BIG_CONSTANT))
seed = torch.tensor(seed)
projection_matrix = performer_helper.create_projection_matrix(
m, d, seed=seed, device=q_.device, dtype=q_.dtype)
q_, k_ = map(lambda t: rearrange(t, 'b h p d -> b p h d'), (q_, k_))
q_prime = performer_helper.softmax_kernel_transformation(
q_,
is_query=True,
projection_matrix=projection_matrix
)
k_prime = performer_helper.softmax_kernel_transformation(
k_,
is_query=False,
projection_matrix=projection_matrix
)
q_prime, k_prime = map(
lambda t: rearrange(t, 'b p h r -> b h p r'), (q_prime, k_prime))
k_prime = rearrange(k_prime, 'b h (f n) r -> b h f n r', f=F)
v_ = rearrange(v_, 'b h (f n) d -> b h f n d', f=F)
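# Linear attention, computed per frame: numerator q'(k'^T v), denominator q'(k'^T 1)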
kv = torch.einsum('b h f n r, b h f n d -> b h f r d', k_prime, v_)
qkv = torch.einsum('b h p r, b h f r d -> b h p f d', q_prime, kv)
normaliser = torch.einsum('b h f n r -> b h f r', k_prime)
normaliser = torch.einsum('b h p r, b h f r -> b h p f', q_prime, normaliser)
x = qkv / normaliser.unsqueeze(-1)
else:
# Using full attention
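# Exact attention: each query patch attends over the patches of every frame,
# with the softmax taken over space within each frame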
q_dot_k = q_ @ k_.transpose(-2, -1)
q_dot_k = rearrange(q_dot_k, 'b q (f n) -> b q f n', f=F)
space_attn = (self.scale * q_dot_k).softmax(dim=-1)
attn = self.attn_drop(space_attn)
v_ = rearrange(v_, 'b (f n) d -> b f n d', f=F, n=P)
x = torch.einsum('b q f n, b f n d -> b q f d', attn, v_)
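# x: (B*h, F*P, F, d) -- one spatially aggregated value per frame for each query patch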
# Temporal attention: query is the similarity-aggregated patch
x = rearrange(x, '(b h) s f d -> b s f (h d)', b=B)
x_diag = rearrange(x, 'b (g n) f d -> b g n f d', g=F)
x_diag = torch.diagonal(x_diag, dim1=-4, dim2=-2)
x_diag = rearrange(x_diag, 'b n d f -> b (f n) d', f=F)
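# x_diag: (B, F*P, h*d) -- each patch's aggregated value taken at its own frame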
q2 = self.proj_q(x_diag)
k2, v2 = self.proj_kv(x).chunk(2, dim=-1)
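# queries come from the frame-diagonal entries; keys/values from the full
# set of F per-frame aggregates of every patch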
q2 = rearrange(q2, 'b s (h d) -> b h s d', h=h)
q2 *= self.scale
k2, v2 = map(
lambda t: rearrange(t, 'b s f (h d) -> b h s f d', f=F, h=h), (k2, v2))
attn = torch.einsum('b h s d, b h s f d -> b h s f', q2, k2)
attn = attn.softmax(dim=-1)
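# attn: (B, h, F*P, F) -- attention over the F per-frame aggregates of each query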
if self.use_original_code:
x = rearrange(x, 'b s f (h d) -> b h s f d', f=F, h=h)
x = torch.einsum('b h s f, b h s f d -> b h s d', attn, x)
else:
x = torch.einsum('b h s f, b h s f d -> b h s d', attn, v2)
x = rearrange(x, 'b h s d -> b s (h d)')
# concat back the cls token
x = torch.cat((cls_out, x), dim=1)
x = self.proj(x)
x = self.proj_drop(x)
return x, attn