in models/transformer.py
def forward(self, src,
            mask: Optional[Tensor] = None,
            src_key_padding_mask: Optional[Tensor] = None,
            pos: Optional[Tensor] = None,
            xyz: Optional[Tensor] = None,
            transpose_swap: bool = False,
            ):
    if transpose_swap:
        # (bs, c, h, w) -> (h*w, bs, c): flatten the spatial dims and put
        # the sequence dimension first for the attention layers
        bs, c, h, w = src.shape
        src = src.flatten(2).permute(2, 0, 1)
        if pos is not None:
            pos = pos.flatten(2).permute(2, 0, 1)
    output = src
    orig_mask = mask
    if isinstance(orig_mask, list):
        # one mask per encoder layer
        assert len(orig_mask) == len(self.layers)
    elif orig_mask is not None:
        # replicate the single mask for every layer
        orig_mask = [mask for _ in range(len(self.layers))]
    for idx, layer in enumerate(self.layers):
        if orig_mask is not None:
            mask = orig_mask[idx]
            # the mask must be tiled to num_heads of the transformer:
            # (bsz, n, n) -> (bsz * nhead, n, n)
            bsz, n, _ = mask.shape
            nhead = layer.nhead
            mask = mask.unsqueeze(1)
            mask = mask.repeat(1, nhead, 1, 1)
            mask = mask.view(bsz * nhead, n, n)
        output = layer(output, src_mask=mask,
                       src_key_padding_mask=src_key_padding_mask, pos=pos)
    if self.norm is not None:
        output = self.norm(output)
    if transpose_swap:
        # undo the entry reshape: (h*w, bs, c) -> (bs, c, h, w)
        output = output.permute(1, 2, 0).view(bs, c, h, w).contiguous()
    # xyz is returned unchanged; this encoder produces no point indices
    xyz_inds = None
    return xyz, output, xyz_inds
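
# --- Illustrative sketch (not part of the repo; sizes below are assumptions):
# the per-head tiling in the loop above expands a (bsz, n, n) attention mask
# to the (bsz * nhead, n, n) layout that nn.MultiheadAttention accepts as a
# 3-D attn_mask.
import torch

bsz, nhead, n = 2, 4, 16
mask = torch.rand(bsz, n, n) > 0.5  # True = position is masked out

tiled = mask.unsqueeze(1).repeat(1, nhead, 1, 1).view(bsz * nhead, n, n)
assert tiled.shape == (bsz * nhead, n, n)
# every head of batch element 0 sees an identical copy of its mask
assert torch.equal(tiled[0], tiled[nhead - 1])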
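
# --- Illustrative sketch (assumed sizes): the transpose_swap reshapes form an
# exact round trip. (bs, c, h, w) is flattened to a sequence-first (h*w, bs, c)
# tensor on entry and restored on exit, exactly as in forward() above.
import torch

bs, c, h, w = 2, 8, 5, 7
src = torch.randn(bs, c, h, w)

seq = src.flatten(2).permute(2, 0, 1)  # entry reshape
assert seq.shape == (h * w, bs, c)

restored = seq.permute(1, 2, 0).view(bs, c, h, w).contiguous()  # exit reshape
assert torch.equal(restored, src)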