in models/swin_transformer_3d.py [0:0]
import numpy as np  # used for the ceil-based padding below; module-level in the file
import torch


def forward_seg(self, x, H, W):
    """Forward function.

    Args:
        x: Input feature, tensor of size (B, D, C, H*W) or (B, D, H, W, C).
        H, W: Spatial resolution of the input feature.
    """
    # calculate attention mask for SW-MSA
    Hp = int(np.ceil(H / self.window_size[1])) * self.window_size[1]
    Wp = int(np.ceil(W / self.window_size[2])) * self.window_size[2]
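    # Hp/Wp round the resolution up to a multiple of the window size, e.g.
    # (illustrative numbers, not values from this repo) H = 50 with
    # self.window_size[1] = 7 gives Hp = ceil(50 / 7) * 7 = 56, so the
    # shifted windows below tile the padded Hp x Wp plane exactly.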
    img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device)  # (1, Hp, Wp, 1)
    h_slices = (
        slice(0, -self.window_size[1]),
        slice(-self.window_size[1], -self.shift_size[1]),
        slice(-self.shift_size[1], None),
    )
    w_slices = (
        slice(0, -self.window_size[2]),
        slice(-self.window_size[2], -self.shift_size[2]),
        slice(-self.shift_size[2], None),
    )
    cnt = 0
    for h in h_slices:
        for w in w_slices:
            img_mask[:, h, w, :] = cnt
            cnt += 1
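    # The nine h/w slice pairs stamp region ids 0..8 onto the padded plane
    # in row-major order:
    #
    #   0 1 2
    #   3 4 5
    #   6 7 8
    #
    # Tokens that land in the same shifted window but carry different ids
    # came from different pre-shift regions and must not attend to each other.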
    mask_windows = window_partition_image(
        img_mask, self.window_size
    )  # (nW, window_size[1], window_size[2], 1)
    mask_windows = mask_windows.view(-1, self.window_size[1] * self.window_size[2])
    attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
    attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(
        attn_mask == 0, float(0.0)
    )
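    # attn_mask has shape (nW, window_area, window_area): the pairwise id
    # difference is 0 for token pairs from the same region (bias 0.0, kept)
    # and nonzero across regions (bias -100.0, which effectively zeroes those
    # pairs after the softmax inside the attention blocks).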
    for blk in self.blocks:
        blk.H, blk.W = H, W
        if x.ndim == 4:
            # Flattened input (B, D, C, H*W): unflatten to (B, D, H, W, C).
            B, D, C, L = x.shape
            assert L == H * W, "input feature has wrong size"
            x = x.reshape(B, D, C, H, W)
            x = x.permute(0, 1, 3, 4, 2)
        assert x.shape[2] == H
        assert x.shape[3] == W
        x = blk(x, attn_mask)
    if self.downsample is not None:
        # Patch merging halves each spatial side (rounding up for odd sizes).
        x_down = self.downsample(x, H, W)
        Wh, Ww = (H + 1) // 2, (W + 1) // 2
        return x, H, W, x_down, Wh, Ww
    else:
        return x, H, W, x, H, W
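
# `window_partition_image` is called above but not shown in this excerpt. A
# minimal sketch of its assumed behavior, following the standard 2D Swin
# window partition (name and signature come from the call site; the body is
# an assumption, not this repo's implementation):
def window_partition_image(x, window_size):
    """Split (B, H, W, C) into non-overlapping windows.

    Returns a tensor of shape (num_windows * B, window_size[1], window_size[2], C).
    """
    B, H, W, C = x.shape
    # Factor H and W into (windows along axis, window length) pairs.
    x = x.view(
        B,
        H // window_size[1], window_size[1],
        W // window_size[2], window_size[2],
        C,
    )
    # Bring the two per-window axes together, then fold windows into batch.
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous()
    return windows.view(-1, window_size[1], window_size[2], C)

# Applied to the (1, Hp, Wp, 1) img_mask above, this yields windows of shape
# (nW, window_size[1], window_size[2], 1), matching the view(-1, ...) that follows.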