in models/transformer.py [0:0]
def forward(self, src,
            mask: Optional[Tensor] = None,
            src_key_padding_mask: Optional[Tensor] = None,
            pos: Optional[Tensor] = None,
            xyz: Optional[Tensor] = None,
            transpose_swap: Optional[bool] = False,
            ):
    if transpose_swap:
        # src arrives as (batch, channel, h, w); flatten the spatial dims and
        # permute to (seq_len, batch, channel), the layout the attention layers expect
        bs, c, h, w = src.shape
        src = src.flatten(2).permute(2, 0, 1)
        if pos is not None:
            pos = pos.flatten(2).permute(2, 0, 1)
    output = src
    xyz_dist = None
    xyz_inds = None
    for idx, layer in enumerate(self.layers):
        mask = None
        if self.masking_radius[idx] > 0:
            # build a radius-based attention mask from the point coordinates,
            # reusing the cached pairwise distances across layers
            mask, xyz_dist = self.compute_mask(xyz, self.masking_radius[idx], xyz_dist)
            # mask must be tiled to num_heads of the transformer
            bsz, n, _ = mask.shape
            nhead = layer.nhead
            mask = mask.unsqueeze(1)
            mask = mask.repeat(1, nhead, 1, 1)
            mask = mask.view(bsz * nhead, n, n)

        output = layer(output, src_mask=mask,
                       src_key_padding_mask=src_key_padding_mask, pos=pos)
        if idx == 0 and self.interim_downsampling:
            # output is (npoints, batch, channel); make it (batch, channel, npoints)
            # for the downsampling module
            output = output.permute(1, 2, 0)
            xyz, output, xyz_inds = self.interim_downsampling(xyz, output)
            # swap back to (npoints, batch, channel)
            output = output.permute(2, 0, 1)
    if self.norm is not None:
        output = self.norm(output)

    if transpose_swap:
        # restore the original (batch, channel, h, w) layout
        output = output.permute(1, 2, 0).view(bs, c, h, w).contiguous()

    return xyz, output, xyz_inds
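
The compute_mask helper called in the loop above is not part of this snippet. A minimal sketch of what it could look like, assuming it thresholds pairwise point distances with torch.cdist and returns the distance matrix for reuse (the repository's actual implementation may differ):

import torch

def compute_mask(self, xyz, radius, dist=None):
    # Hypothetical sketch, not the repository's exact code.
    # xyz: (batch, npoints, 3) point coordinates; radius: scalar masking radius
    with torch.no_grad():
        if dist is None:
            # pairwise Euclidean distances, computed once and handed back to the caller
            dist = torch.cdist(xyz, xyz, p=2)
        # True entries are excluded from self-attention, so points farther
        # apart than the radius do not attend to each other
        mask = dist >= radius
    return mask, dist

Returning dist alongside the mask lets the forward loop pass xyz_dist back in on later iterations, so the (potentially large) npoints x npoints distance matrix is computed once rather than once per layer.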