in models/swin_transformer_3d.py [0:0]
def forward_part1(self, x, mask_matrix):
    """Run (shifted-)window self-attention over a 5-D feature map.

    The input is layer-normalised with ``self.norm1`` and zero-padded at the
    end of the D/H/W axes up to a multiple of the effective window size. It
    is then optionally cyclically shifted, split into windows, passed through
    ``self.attn``, stitched back together, un-shifted, and cropped to its
    original extent.

    Args:
        x: Input tensor, unpacked as ``(B, D, H, W, C)``.
        mask_matrix: Attention mask handed to ``self.attn`` only when a
            non-zero shift is in effect; when there is no shift, ``None`` is
            passed instead and this argument is ignored.

    Returns:
        The attended tensor, cropped back to the original ``D, H, W`` extent
        whenever padding was added.
    """
    B, D, H, W, C = x.shape
    window_size, shift_size = get_window_size(
        (D, H, W), self.window_size, self.shift_size
    )
    needs_shift = any(s > 0 for s in shift_size)

    x = self.norm1(x)

    # Pad only at the trailing side of each spatial axis (back, bottom,
    # right). Each amount is zero when that axis is already window-aligned.
    pad_d1, pad_b, pad_r = (
        (win - extent % win) % win
        for extent, win in zip((D, H, W), window_size)
    )
    # F.pad lists (before, after) pairs starting from the LAST dim:
    # C, W, H, D.
    x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b, 0, pad_d1))
    _, Dp, Hp, Wp, _ = x.shape

    # Cyclic shift toward the origin. The mask is only meaningful for the
    # shifted layout.
    if needs_shift:
        shifted_x = torch.roll(
            x, shifts=tuple(-s for s in shift_size[:3]), dims=(1, 2, 3)
        )
        attn_mask = mask_matrix
    else:
        shifted_x, attn_mask = x, None

    # W-MSA (no shift) or SW-MSA (shifted + masked) on non-overlapping
    # windows. Per the original comments, window_partition is expected to
    # yield (B*nW, Wd*Wh*Ww, C) -- TODO confirm against its definition.
    windows = window_partition(shifted_x, window_size)
    windows = self.attn(windows, mask=attn_mask)
    windows = windows.view(-1, *window_size, C)
    shifted_x = window_reverse(windows, window_size, B, Dp, Hp, Wp)

    # Undo the cyclic shift by rolling the same amounts the other way.
    out = (
        torch.roll(shifted_x, shifts=tuple(shift_size[:3]), dims=(1, 2, 3))
        if needs_shift
        else shifted_x
    )

    # Remove the alignment padding, if any was added.
    if max(pad_d1, pad_b, pad_r) > 0:
        out = out[:, :D, :H, :W, :].contiguous()
    return out