# src/transformers/models/vidtr/multihead_attention.py
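# NOTE: excerpt of the multi-head attention module's forward pass (VidTr
# separable temporal/spatial attention); `torch` (used for torch.cat below)
# is assumed to be imported at the top of the module.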
def forward(self, q, k, value, orig_shape, attn_mask=None, key_padding_mask=None):
    # `value` and `key_padding_mask` are accepted for API parity but unused:
    # this is self-attention, so `q` serves as the source of Q, K and V.
    mask = attn_mask
    b, c, t, w, h = orig_shape
    seq_l, sz_b, c = q.shape  # note: this `c` (embedding dim) shadows the channel dim above
    # Single in-projection produces Q, K and V in one matmul, split along the last dim.
    qkv = q @ self.in_proj_weight.t() + self.in_proj_bias
    q = qkv[:, :, :c]
    k = qkv[:, :, c:2 * c]
    v = qkv[:, :, 2 * c:]
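    # Tokens are laid out as a (t + 1) x (w * h + 1) grid (temporal x spatial,
    # each axis presumably carrying a cls slot). Fold that grid out of the
    # sequence dim, split heads, and flatten to per-head temporal sequences of
    # length t + 1 for the temporal attention below.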
    q_t = (q.view(t + 1, w * h + 1, b, self.head_num, c // self.head_num)
           .permute(3, 2, 1, 0, 4).contiguous()
           .view(-1, t + 1, self.d_model // self.head_num))
    k_t = (k.view(t + 1, w * h + 1, b, self.head_num, c // self.head_num)
           .permute(3, 2, 1, 0, 4).contiguous()
           .view(-1, t + 1, self.d_model // self.head_num))
    v_t = (v.view(t + 1, w * h + 1, b, self.head_num, c // self.head_num)
           .permute(3, 2, 1, 0, 4).contiguous()
           .view(-1, t + 1, self.d_model // self.head_num))
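    # Optional temporal pooling: when enabled, self.attention_t additionally
    # returns the indices of the temporal tokens it keeps, shrinking the
    # temporal axis from t + 1 to t + 1 - self.k.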
    if self.pool:
        output_t, attnx, idx = self.attention_t(q_t, k_t, v_t, mask=mask)
        # Broadcast the kept-token indices across the per-head channel dim for gather.
        idx_ = idx[:, :, :1].repeat(1, 1, c // self.head_num)
        v_s = (output_t.view(self.head_num, b, w * h + 1, t + 1 - self.k, self.d_model // self.head_num)
               .permute(0, 1, 3, 2, 4).contiguous()
               .view(self.head_num * sz_b * (t + 1 - self.k), w * h + 1, -1))
        # Gather the retained temporal tokens for q and k, re-attaching the cls token.
        q_cls = q_t[:, :1, :]
        q_t = q_t[:, 1:, :]
        q_t = q_t.gather(1, idx_)
        q_t = torch.cat((q_cls, q_t), dim=1)
        k_cls = k_t[:, :1, :]
        k_t = k_t[:, 1:, :]
        k_t = k_t.gather(1, idx_)
        k_t = torch.cat((k_cls, k_t), dim=1)
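        # Re-flatten the pooled streams so spatial attention sees
        # head_num * b * (t + 1 - k) sequences of length w * h + 1.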
        q_s = (q_t.view(self.head_num, b, w * h + 1, t + 1 - self.k, self.d_model // self.head_num)
               .permute(0, 1, 3, 2, 4).contiguous()
               .view(self.head_num * sz_b * (t + 1 - self.k), w * h + 1, -1))
        k_s = (k_t.view(self.head_num, b, w * h + 1, t + 1 - self.k, self.d_model // self.head_num)
               .permute(0, 1, 3, 2, 4).contiguous()
               .view(self.head_num * sz_b * (t + 1 - self.k), w * h + 1, -1))
    else:
        # No pooling: temporal attention keeps all t + 1 temporal tokens.
        output_t, attnx = self.attention_t(q_t, k_t, v_t, mask=mask)
        idx = None
        v_s = (output_t.view(self.head_num, b, w * h + 1, t + 1, self.d_model // self.head_num)
               .permute(0, 1, 3, 2, 4).contiguous()
               .view(self.head_num * sz_b * (t + 1), w * h + 1, -1))
        q_s = (q_t.view(self.head_num, b, w * h + 1, t + 1, self.d_model // self.head_num)
               .permute(0, 1, 3, 2, 4).contiguous()
               .view(self.head_num * sz_b * (t + 1), w * h + 1, -1))
        k_s = (k_t.view(self.head_num, b, w * h + 1, t + 1, self.d_model // self.head_num)
               .permute(0, 1, 3, 2, 4).contiguous()
               .view(self.head_num * sz_b * (t + 1), w * h + 1, -1))
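    # Spatial attention across the w * h + 1 spatial tokens of every
    # (head, batch, time) slice.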
    output_s, attn = self.attention_s(q_s, k_s, v_s, mask=mask)
    # Fold heads back into the channel dim and restore (seq, batch, channel) layout.
    output = (output_s.view(self.head_num, b, -1, w * h + 1, self.d_model // self.head_num)
              .permute(2, 3, 1, 0, 4).contiguous()
              .view(-1, b, c))
    output = self.out_proj(output)
    return output, attn, idx
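
# Shape sketch (illustrative, not from the source; assumes d_model == c and
# an input of seq_l == (t + 1) * (w * h + 1) tokens, i.e. a cls slot plus
# t x (w * h) patch tokens across the two axes):
#   q, k, value : (seq_l, sz_b, c)
#   output      : ((t + 1 - self.k) * (w * h + 1), sz_b, c) if self.pool,
#                 else (seq_l, sz_b, c)
#   attn        : spatial attention weights from self.attention_s
#   idx         : indices of the retained temporal tokens, or None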