in src/transformers/models/vidtr/multihead_attention.py [0:0]
def forward(self, q, k, value, orig_shape, attn_mask=None, key_padding_mask=None):
    # Note: k, value, and key_padding_mask are not used; q, k, and v are all
    # re-derived from q via the joint input projection below.
    mask = attn_mask
    b, c, t, w, h = orig_shape
    seq_l, sz_b, c = q.shape  # sz_b matches b from orig_shape

    # Split off the class token; separable attention runs on the patch tokens only.
    q_cls = q[:1, :, :]
    q = q[1:, :, :]

    # Joint input projection yields q, k, v in a single matmul.
    qkv = q @ self.in_proj_weight.t() + self.in_proj_bias
    q = qkv[:, :, :c]
    k = qkv[:, :, c:2 * c]
    v = qkv[:, :, 2 * c:]

    # Regroup into per-head temporal sequences of shape
    # (head_num * b * w * h, t, c // head_num) for the temporal attention pass.
    q_t = q.view(t, w * h, b, self.head_num, c // self.head_num).permute(3, 2, 1, 0, 4).contiguous().view(-1, t, self.d_model // self.head_num)
    k_t = k.view(t, w * h, b, self.head_num, c // self.head_num).permute(3, 2, 1, 0, 4).contiguous().view(-1, t, self.d_model // self.head_num)
    v_t = v.view(t, w * h, b, self.head_num, c // self.head_num).permute(3, 2, 1, 0, 4).contiguous().view(-1, t, self.d_model // self.head_num)

    if self.pool:
        # Pooled temporal attention drops self.k time steps and returns the kept indices.
        output_t, attnx, idx = self.attention_t(q_t, k_t, v_t, mask=mask)
        v_s = output_t.view(self.head_num, b, w * h, t - self.k, self.d_model // self.head_num).permute(0, 1, 3, 2, 4).contiguous().view(self.head_num * sz_b * (t - self.k), w * h, -1)

        # Gather q and k at the kept temporal positions before the spatial pass.
        idx_ = idx[:, :, :1].repeat(1, 1, c // self.head_num)
        q_t = q_t.gather(1, idx_)
        k_t = k_t.gather(1, idx_)
        q_s = q_t.view(self.head_num, b, w * h, t - self.k, self.d_model // self.head_num).permute(0, 1, 3, 2, 4).contiguous().view(self.head_num * sz_b * (t - self.k), w * h, -1)
        k_s = k_t.view(self.head_num, b, w * h, t - self.k, self.d_model // self.head_num).permute(0, 1, 3, 2, 4).contiguous().view(self.head_num * sz_b * (t - self.k), w * h, -1)
    else:
        output_t, attnx = self.attention_t(q_t, k_t, v_t, mask=mask)
        idx = None
        v_s = output_t.view(self.head_num, b, w * h, t, self.d_model // self.head_num).permute(0, 1, 3, 2, 4).contiguous().view(self.head_num * sz_b * t, w * h, -1)
        q_s = q_t.view(self.head_num, b, w * h, t, self.d_model // self.head_num).permute(0, 1, 3, 2, 4).contiguous().view(self.head_num * sz_b * t, w * h, -1)
        k_s = k_t.view(self.head_num, b, w * h, t, self.d_model // self.head_num).permute(0, 1, 3, 2, 4).contiguous().view(self.head_num * sz_b * t, w * h, -1)

    # Spatial attention over the w * h patch positions.
    output_s, attn = self.attention_s(q_s, k_s, v_s, mask=mask)
    _, seq_l, _ = output_s.shape

    # Fold heads back into the channel dimension, project, and prepend the class token.
    output = output_s.view(self.head_num, b, -1, w * h, self.d_model // self.head_num).permute(2, 3, 1, 0, 4).contiguous().view(-1, b, c)
    output = self.out_proj(output)
    output = torch.cat((q_cls, output), dim=0)
    return output, attn, idx
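
For orientation, a minimal, self-contained shape walk-through of the temporal regrouping performed above; the concrete sizes (b, c, t, w, h) and head_num are illustrative assumptions, not values taken from the VidTr configuration.

import torch

# Illustrative sizes only (assumed, not from the model config).
b, c, t, w, h = 2, 64, 8, 7, 7
head_num = 8
q = torch.randn(t * w * h, b, c)  # patch tokens after the class token is split off

q_t = (
    q.view(t, w * h, b, head_num, c // head_num)
     .permute(3, 2, 1, 0, 4)          # (head_num, b, w*h, t, c // head_num)
     .contiguous()
     .view(-1, t, c // head_num)      # (head_num * b * w*h, t, c // head_num)
)
print(q_t.shape)  # torch.Size([784, 8, 8])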