def forward()

in src/transformers/models/vidtr/multihead_attention.py

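This forward implements VidTR-style separable attention: a fused q/k/v projection, temporal attention at every spatial location (optionally with pooling that drops self.k temporal positions), spatial attention within each remaining frame, an output projection, and a class token that bypasses both attention stages. The sketch below is not from the source; the toy dimensions and the flattening order are assumptions inferred from the reshapes in the code. It shows the token layout the function expects for q and orig_shape.

    import torch

    b, c, t, w, h = 2, 96, 8, 7, 7              # assumed toy dimensions
    feats = torch.randn(b, c, t, w, h)          # video feature volume, matches orig_shape
    tokens = feats.flatten(2).permute(2, 0, 1)  # (t*w*h, b, c), frame index varies slowest
    cls_tok = torch.zeros(1, b, c)              # placeholder class token
    q = torch.cat((cls_tok, tokens), dim=0)     # (1 + t*w*h, b, c), sequence-first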

    def forward(self, q, k, value, orig_shape, attn_mask=None, key_padding_mask=None):
        mask = attn_mask

        # orig_shape describes the original video feature volume; q arrives as a
        # (1 + t*w*h, b, c) sequence-first tensor with the class token at position 0.
        b, c, t, w, h = orig_shape
        seq_l, sz_b, c = q.shape

        # Split off the class token; both attention stages run on patch tokens only.
        q_cls = q[:1, :, :]
        q = q[1:, :, :]

        # Fused projection producing q, k and v from the patch tokens
        # (the k, value and key_padding_mask arguments are not used).
        qkv = q @ self.in_proj_weight.t() + self.in_proj_bias
        q = qkv[:, :, :c]
        k = qkv[:, :, c: 2 * c]
        v = qkv[:, :, 2 * c:]

        # Per-head temporal sequences: (t*w*h, b, c) -> (head_num * b * w*h, t, head_dim),
        # so every spatial location attends along the temporal axis.
        q_t = q.view(t, w * h, b, self.head_num, c // self.head_num).permute(3, 2, 1, 0, 4)
        q_t = q_t.contiguous().view(-1, t, self.d_model // self.head_num)
        k_t = k.view(t, w * h, b, self.head_num, c // self.head_num).permute(3, 2, 1, 0, 4)
        k_t = k_t.contiguous().view(-1, t, self.d_model // self.head_num)
        v_t = v.view(t, w * h, b, self.head_num, c // self.head_num).permute(3, 2, 1, 0, 4)
        v_t = v_t.contiguous().view(-1, t, self.d_model // self.head_num)

        if self.pool:
            # Temporal attention with pooling: self.k temporal positions are dropped
            # and idx holds the indices of the kept positions.
            output_t, attnx, idx = self.attention_t(q_t, k_t, v_t, mask=mask)
            v_s = output_t.view(self.head_num, b, w * h, t - self.k, self.d_model // self.head_num)
            v_s = v_s.permute(0, 1, 3, 2, 4).contiguous().view(self.head_num * sz_b * (t - self.k), w * h, -1)

            # Keep the queries/keys only at the retained temporal positions.
            idx_ = idx[:, :, :1].repeat(1, 1, c // self.head_num)
            q_t = q_t.gather(1, idx_)
            k_t = k_t.gather(1, idx_)

            # Regroup per frame for spatial attention: (head_num * b * (t - k), w*h, head_dim).
            q_s = q_t.view(self.head_num, b, w * h, t - self.k, self.d_model // self.head_num)
            q_s = q_s.permute(0, 1, 3, 2, 4).contiguous().view(self.head_num * sz_b * (t - self.k), w * h, -1)
            k_s = k_t.view(self.head_num, b, w * h, t - self.k, self.d_model // self.head_num)
            k_s = k_s.permute(0, 1, 3, 2, 4).contiguous().view(self.head_num * sz_b * (t - self.k), w * h, -1)
        else:
            # No temporal pooling: all t frames are kept.
            output_t, attnx = self.attention_t(q_t, k_t, v_t, mask=mask)
            idx = None
            v_s = output_t.view(self.head_num, b, w * h, t, self.d_model // self.head_num)
            v_s = v_s.permute(0, 1, 3, 2, 4).contiguous().view(self.head_num * sz_b * t, w * h, -1)

            # Regroup per frame for spatial attention: (head_num * b * t, w*h, head_dim).
            q_s = q_t.view(self.head_num, b, w * h, t, self.d_model // self.head_num)
            q_s = q_s.permute(0, 1, 3, 2, 4).contiguous().view(self.head_num * sz_b * t, w * h, -1)
            k_s = k_t.view(self.head_num, b, w * h, t, self.d_model // self.head_num)
            k_s = k_s.permute(0, 1, 3, 2, 4).contiguous().view(self.head_num * sz_b * t, w * h, -1)

        # Spatial attention: each (kept) frame attends over its w*h positions, with the
        # temporally attended output as values.
        output_s, attn = self.attention_s(q_s, k_s, v_s, mask=mask)
        _, seq_l, _ = output_s.shape

        # Merge heads, restore the (seq_len, b, c) layout, project, and re-attach the
        # untouched class token.
        output = output_s.view(self.head_num, b, -1, w * h, self.d_model // self.head_num)
        output = output.permute(2, 3, 1, 0, 4).contiguous().view(-1, b, c)
        output = self.out_proj(output)
        output = torch.cat((q_cls, output), dim=0)

        return output, attn, idx
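
For reference, a minimal, self-contained sketch of the temporal-then-spatial data movement the reshapes above implement, written with plain softmax attention and without the fused projection, pooling, or class token. It illustrates the pattern only and is not the module's actual attention_t/attention_s.

    import torch
    import torch.nn.functional as F

    def separable_attention_sketch(q, b, c, t, w, h, head_num=8):
        head_dim = c // head_num
        x = q.view(t, w * h, b, head_num, head_dim)

        # Temporal stage: each spatial location attends across its t frames.
        xt = x.permute(3, 2, 1, 0, 4).reshape(-1, t, head_dim)         # (head_num*b*w*h, t, d)
        a_t = F.softmax(xt @ xt.transpose(1, 2) / head_dim ** 0.5, dim=-1)
        xt = a_t @ xt

        # Spatial stage: each frame attends across its w*h positions.
        xs = xt.view(head_num, b, w * h, t, head_dim).permute(0, 1, 3, 2, 4)
        xs = xs.reshape(-1, w * h, head_dim)                           # (head_num*b*t, w*h, d)
        a_s = F.softmax(xs @ xs.transpose(1, 2) / head_dim ** 0.5, dim=-1)
        xs = a_s @ xs

        # Back to the (t*w*h, b, c) sequence layout.
        out = xs.view(head_num, b, t, w * h, head_dim).permute(2, 3, 1, 0, 4)
        return out.reshape(t * w * h, b, c)

    b, c, t, w, h = 2, 96, 8, 7, 7
    q = torch.randn(t * w * h, b, c)                    # patch tokens only, no class token
    out = separable_attention_sketch(q, b, c, t, w, h)
    print(out.shape)                                    # torch.Size([392, 2, 96])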