def forward_pre()

in src/transformers/models/vidtr/vidtr_compact.py [0:0]


    def forward_pre(
        self,
        src,
        orig_shape,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
    ):
        """Apply one pre-norm encoder layer, pooling the temporal axis if configured.

        The layer normalises ``src``, runs self-attention on it, adds the
        attention output back as a residual, and finishes with the usual
        normalised feed-forward residual block.  When ``self.layer_index`` is
        listed in ``self.layer_pool``, the residual stream itself is first
        reduced along the temporal axis so that it lines up with the (shorter)
        attention output.

        Args:
            src: Token tensor.  In the pooling branch it is reshaped as
                ``(1 + t * w * h, b, c)``: one leading token (presumably the
                class token -- confirm against the caller) followed by
                ``t * w * h`` video tokens.
            orig_shape: ``(b, c, t, w, h)`` tuple describing the video layout
                of ``src``.
            src_mask: Forwarded to the attention module as ``attn_mask``.
            src_key_padding_mask: Forwarded to the attention module as
                ``key_padding_mask``.
            pos: Positional embedding; only combined with the input (through
                ``self.with_pos_embed``) when ``self.layer_index == 0``.

        Returns:
            A ``(src, pos, orig_shape)`` tuple.  ``pos`` is returned unchanged;
            ``orig_shape`` has its temporal entry reduced by the configured
            number of keys when this layer pools, and is otherwise the input
            value.
        """
        b, c, t, w, h = orig_shape

        normed = self.norm1(src)
        # Only the first layer mixes the positional embedding into its input.
        attn_in = self.with_pos_embed(normed, pos) if self.layer_index == 0 else normed

        # Layers up to and including the last pooling layer additionally pass
        # the video layout to the attention module; later layers do not.
        shape_kwarg = {"orig_shape": orig_shape} if self.layer_index <= self.layer_pool[-1] else {}
        attn_result = self.self_attn(
            attn_in,
            attn_in,
            value=attn_in,
            **shape_kwarg,
            attn_mask=src_mask,
            key_padding_mask=src_key_padding_mask,
        )
        attn_out = attn_result[0]

        if self.layer_index in self.layer_pool:
            head_dim = c // self.nhead
            positions = w * h

            # Third element of the attention result is used as a gather index
            # along the temporal axis (presumably the frames the attention
            # kept -- confirm against the attention module).  Its first
            # channel is broadcast over every channel of a head.
            keep_idx = attn_result[2][:, :, :1].repeat(1, 1, head_dim)

            cls_token = src[:1, :, :]
            frame_tokens = src[1:, :, :]

            # (t*w*h, b, c) -> (nhead * b * w*h, t, head_dim) so that time is
            # the axis being gathered over.
            per_head = frame_tokens.view(t, positions, b, self.nhead, head_dim)
            per_head = per_head.permute(3, 2, 1, 0, 4).contiguous()
            per_head = per_head.view(self.nhead * b * positions, t, head_dim)
            kept = per_head.gather(1, keep_idx)

            # Undo the reshape with the reduced temporal length.
            kept_t = t - self.number_of_keys[self.layer_pool.index(self.layer_index)]
            kept = kept.view(self.nhead, b, positions, kept_t, head_dim)
            kept = kept.permute(3, 2, 1, 0, 4).contiguous().view(-1, b, c)

            src = torch.cat((cls_token, kept), dim=0)
            orig_shape = (b, c, kept_t, w, h)

        # Attention residual.
        src = src + self.dropout1(attn_out)

        # Feed-forward residual on the normalised stream.
        ff_in = self.norm2(src)
        ff_hidden = self.activation(self.linear1(ff_in))
        ff_out = self.linear2(self.dropout(ff_hidden))
        src = src + self.dropout2(ff_out)
        return src, pos, orig_shape