in src/transformers/models/vidtr/vidtr_compact.py [0:0]
def forward_pre(self, src, orig_shape,
                src_mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None):
    b, c, t, w, h = orig_shape
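    # Pre-norm variant: LayerNorm is applied before attention, the residual is added afterwards.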
    src2 = self.norm1(src)
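    # The positional embedding is only injected at the first encoder layer.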
    if self.layer_index == 0:
        q = k = v = self.with_pos_embed(src2, pos)
    else:
        q = k = v = src2
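    # Up to and including the last pooling layer, the attention module is given the current
    # (b, c, t, w, h) shape; at pooling layers its output tuple also carries the indices of
    # the temporal positions it keeps.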
    if self.layer_index <= self.layer_pool[-1]:
        src_attn = self.self_attn(q, k, value=v, orig_shape=orig_shape, attn_mask=src_mask,
                                  key_padding_mask=src_key_padding_mask)
    else:
        src_attn = self.self_attn(q, k, value=v, attn_mask=src_mask,
                                  key_padding_mask=src_key_padding_mask)
    src2 = src_attn[0]
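    # At pooling layers the attention output src2 already has a reduced temporal length, so the
    # residual branch (src) is down-sampled below to the same set of temporal positions.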
    if self.layer_index in self.layer_pool:
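        # src_attn[2] holds the indices of the kept temporal positions; take one index per
        # position and repeat it across the per-head channel dimension for gather().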
        idx = src_attn[2]
        idx = idx[:, :, :1].repeat(1, 1, c // self.nhead)
        b, c, t, w, h = orig_shape
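        # The class token is never pooled; only the t * w * h spatio-temporal tokens are gathered.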
        src_cls = src[:1, :, :]
        src_twh = src[1:, :, :]
        # Rearrange the tokens to (nhead * b * w * h, t, head_dim) so the kept temporal
        # positions can be selected by index.
        src_twh = src_twh.view(t, w * h, b, self.nhead, c // self.nhead).permute(
            3, 2, 1, 0, 4).contiguous().view(self.nhead * b * (w * h), t, c // self.nhead)
        src_twh = src_twh.gather(1, idx)
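        # Fold heads and spatial positions back into (tokens, batch, channels) with the
        # reduced temporal length.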
        src_twh = src_twh.view(
            self.nhead, b, w * h,
            t - self.number_of_keys[self.layer_pool.index(self.layer_index)],
            c // self.nhead).permute(3, 2, 1, 0, 4).contiguous().view(-1, b, c)
        src = torch.cat((src_cls, src_twh), dim=0)
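        # Record the reduced temporal length so the layers that follow see the updated shape.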
        orig_shape = (b, c, t - self.number_of_keys[self.layer_pool.index(self.layer_index)], w, h)
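    # Residual connection around the attention output, then the pre-norm feed-forward block.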
    src = src + self.dropout1(src2)
    src2 = self.norm2(src)
    src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
    src = src + self.dropout2(src2)
    return src, pos, orig_shape