in models/transformer.py [0:0]
def forward(self, tgt, memory,
            tgt_mask: Optional[Tensor] = None,
            memory_mask: Optional[Tensor] = None,
            tgt_key_padding_mask: Optional[Tensor] = None,
            memory_key_padding_mask: Optional[Tensor] = None,
            pos: Optional[Tensor] = None,
            query_pos: Optional[Tensor] = None,
            transpose_swap: Optional[bool] = False,
            return_attn_weights: Optional[bool] = False,
            ):
    if transpose_swap:
        bs, c, h, w = memory.shape
        memory = memory.flatten(2).permute(2, 0, 1)  # memory: (bs, c, h, w) -> (h*w, bs, c)
        if pos is not None:
            pos = pos.flatten(2).permute(2, 0, 1)

    output = tgt
    intermediate = []
    attns = []

    for layer in self.layers:
        output, attn = layer(output, memory, tgt_mask=tgt_mask,
                             memory_mask=memory_mask,
                             tgt_key_padding_mask=tgt_key_padding_mask,
                             memory_key_padding_mask=memory_key_padding_mask,
                             pos=pos, query_pos=query_pos,
                             return_attn_weights=return_attn_weights)
        if self.return_intermediate:
            intermediate.append(self.norm(output))
        if return_attn_weights:
            attns.append(attn)

    if self.norm is not None:
        output = self.norm(output)
        if self.return_intermediate:
            # Replace the last per-layer entry with the freshly normalized output.
            intermediate.pop()
            intermediate.append(output)

    if return_attn_weights:
        attns = torch.stack(attns)

    if self.return_intermediate:
        return torch.stack(intermediate), attns

    return output, attns
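
For reference, a minimal shape walk-through of the `transpose_swap` branch, plus a hedged example call. The tensor sizes, the layer count `N`, and the `decoder` instance are illustrative assumptions, not values fixed by the code above.

    import torch

    # Assumed sizes, for illustration only.
    bs, c, h, w = 2, 256, 16, 16
    num_queries = 100

    memory = torch.randn(bs, c, h, w)        # encoder feature map
    pos = torch.randn(bs, c, h, w)           # positional encoding, same layout as memory
    tgt = torch.zeros(num_queries, bs, c)    # decoder queries, sequence-first
    query_pos = torch.randn(num_queries, bs, c)

    # What the transpose_swap branch does to memory (and pos):
    # (bs, c, h, w) -> flatten(2) -> (bs, c, h*w) -> permute(2, 0, 1) -> (h*w, bs, c)
    memory_seq = memory.flatten(2).permute(2, 0, 1)
    print(memory_seq.shape)  # torch.Size([256, 2, 256])

    # Assuming `decoder` is an instance of the class that owns the forward above,
    # built with return_intermediate=True and N decoder layers:
    # hs, attns = decoder(tgt, memory, pos=pos, query_pos=query_pos,
    #                     transpose_swap=True, return_attn_weights=True)
    # hs.shape -> (N, num_queries, bs, c)   per-layer normalized outputs
    # attns    -> stack of the attention weights returned by each layer

Note that when `return_attn_weights` is False, `attns` stays an empty list and is still returned alongside the (possibly stacked) outputs.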