in models/swin_transformer_3d.py [0:0]
def forward(self, x, use_checkpoint=False, H=None, W=None, use_seg=False):
    """Run all blocks of this stage on a 5-D feature map, then optionally downsample.

    Args:
        x: Input feature, tensor size (B, C, D, H, W).
        use_checkpoint: Forwarded unchanged to every block as its
            ``use_checkpoint`` keyword argument.
        H, W: Only used when ``use_seg`` is true, in which case they are passed
            to ``self.forward_seg``. On the default path they are ignored and
            re-derived from ``x.shape``.
        use_seg: If true, skip this method's logic entirely and return
            ``self.forward_seg(x, H, W)``.

    Returns:
        The processed feature map, converted back to channels-first
        ("b c d h w") layout after the blocks and the optional
        ``self.downsample``.
    """
    if use_seg:
        return self.forward_seg(x, H, W)

    batch, _, depth, height, width = x.shape
    window_size, shift_size = get_window_size(
        (depth, height, width), self.window_size, self.shift_size
    )

    # The shifted-window attention mask is built for the extents rounded up to a
    # whole number of windows along each of the three axes.
    padded_d, padded_h, padded_w = (
        int(np.ceil(extent / window_size[axis])) * window_size[axis]
        for axis, extent in enumerate((depth, height, width))
    )
    attn_mask = compute_mask(
        padded_d, padded_h, padded_w, window_size, shift_size, x.device
    )

    # Blocks operate on channels-last tensors.
    x = rearrange(x, "b c d h w -> b d h w c")
    for block in self.blocks:
        x = block(x, attn_mask, use_checkpoint=use_checkpoint)
    x = x.view(batch, depth, height, width, -1)

    if self.downsample is not None:
        x = self.downsample(x)
    return rearrange(x, "b d h w c -> b c d h w")