def forward()

in src/transformers/models/vidtr/multihead_attention.py
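
This forward pass projects the packed input into Q, K and V, applies attention along the temporal axis independently for every spatial location, then along the spatial axis for every remaining time step, and finally re-projects the result. When pooling is enabled, the temporal stage also selects a reduced set of frames and returns their indices to the caller.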


    def forward(self, q, k, value, orig_shape, attn_mask=None, key_padding_mask=None):
        mask = attn_mask

        # orig_shape carries the original video dimensions; q arrives as (seq_len, batch, channels)
        # with seq_len == (t + 1) * (w * h + 1).
        b, c, t, w, h = orig_shape

        seq_l, sz_b, c = q.shape

        # A single in-projection produces Q, K and V packed along the channel axis; split them back out.
        qkv = q @ self.in_proj_weight.t() + self.in_proj_bias

        q = qkv[:, :, :c]
        k = qkv[:, :, c: 2 * c]
        v = qkv[:, :, 2 * c:]

        # Reshape to (head_num * b * (w * h + 1), t + 1, head_dim): temporal attention runs independently
        # for every head, batch element and spatial location (including the spatial class token).
        q_t = q.view(t + 1, w * h + 1, b, self.head_num, c // self.head_num).permute(3, 2, 1, 0, 4).contiguous().view(
            -1, t + 1, self.d_model // self.head_num)
        k_t = k.view(t + 1, w * h + 1, b, self.head_num, c // self.head_num).permute(3, 2, 1, 0, 4).contiguous().view(
            -1, t + 1, self.d_model // self.head_num)
        v_t = v.view(t + 1, w * h + 1, b, self.head_num, c // self.head_num).permute(3, 2, 1, 0, 4).contiguous().view(
            -1, t + 1, self.d_model // self.head_num)

        if self.pool:
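            # attention_t with pooling also returns the indices of the frames it keeps (used to gather q and k below).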
            output_t, attnx, idx = self.attention_t(q_t, k_t, v_t, mask=mask)

            # idx indexes the t - self.k frames kept by pooling; expand it so gather picks up whole head-dim vectors.
            idx_ = idx[:, :, :1].repeat(1, 1, c // self.head_num)

            # Regroup the pooled temporal output to (head_num * b * (t + 1 - k), w * h + 1, head_dim) for spatial attention.
            v_s = output_t.view(self.head_num, b, w * h + 1, t + 1 - self.k, self.d_model // self.head_num).permute(0, 1, 3, 2, 4).contiguous().view(
                self.head_num * sz_b * (t + 1 - self.k), w * h + 1, -1)

            q_cls = q_t[:, :1, :]
            q_t = q_t[:, 1:, :]
            q_t = q_t.gather(1, idx_)
            q_t = torch.cat((q_cls, q_t), dim=1)

            k_cls = k_t[:, :1, :]
            k_t = k_t[:, 1:, :]
            k_t = k_t.gather(1, idx_)
            k_t = torch.cat((k_cls, k_t), dim=1)

            # Queries and keys, now restricted to the kept frames, regrouped the same way for spatial attention.
            q_s = q_t.view(self.head_num, b, w * h + 1, t + 1 - self.k, self.d_model // self.head_num).permute(0, 1, 3, 2, 4).contiguous().view(
                self.head_num * sz_b * (t + 1 - self.k), w * h + 1, -1)
            k_s = k_t.view(self.head_num, b, w * h + 1, t + 1 - self.k, self.d_model // self.head_num).permute(0, 1, 3, 2, 4).contiguous().view(
                self.head_num * sz_b * (t + 1 - self.k), w * h + 1, -1)
        else:
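            # Without pooling, all t + 1 temporal positions are kept and no frame indices are returned.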
            output_t, attnx = self.attention_t(q_t, k_t, v_t, mask=mask)
            idx = None
            v_s = output_t.view(self.head_num, b, w * h + 1, t + 1, self.d_model // self.head_num).permute(0, 1, 3, 2, 4).contiguous().view(
                self.head_num * sz_b * (t + 1), w * h + 1, -1)

            q_s = q_t.view(self.head_num, b, w * h + 1, t + 1, self.d_model // self.head_num).permute(0, 1, 3, 2, 4).contiguous().view(
                self.head_num * sz_b * (t + 1), w * h + 1, -1)
            k_s = k_t.view(self.head_num, b, w * h + 1, t + 1, self.d_model // self.head_num).permute(0, 1, 3, 2, 4).contiguous().view(
                self.head_num * sz_b * (t + 1), w * h + 1, -1)

        # Spatial attention over the w * h + 1 tokens at every remaining time step.
        output_s, attn = self.attention_s(q_s, k_s, v_s, mask=mask)
        _, seq_l, _ = output_s.shape
        # Fold the heads back together and restore the (sequence, batch, channels) layout.
        output = output_s.view(self.head_num, b, -1, w * h + 1, self.d_model // self.head_num).permute(2, 3, 1, 0, 4).contiguous().view(-1, b, c)

        output = self.out_proj(output)

        return output, attn, idx
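
A minimal usage sketch follows; it only illustrates the tensor layout this forward expects. The class name and constructor arguments (d_model, head_num, pool, k) are assumptions inferred from the attributes used above, so the instantiation is left commented out.

    import torch

    # Hypothetical sizes; real values depend on the model configuration.
    b, c, t, w, h = 2, 768, 8, 14, 14       # batch, channels, frames, spatial token grid
    seq_len = (t + 1) * (w * h + 1)         # one extra temporal and one extra spatial class token
    x = torch.randn(seq_len, b, c)          # layout expected by forward(): (sequence, batch, channels)

    # attn = MultiheadAttention(d_model=c, head_num=12, pool=True, k=2)   # hypothetical constructor
    # out, attn_weights, idx = attn(x, x, x, orig_shape=(b, c, t, w, h))
    # With pool=True the temporal length shrinks by k, so out has shape
    # ((t + 1 - k) * (w * h + 1), b, c); with pool=False it stays (seq_len, b, c).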