def forward()

in models/swin_transformer_3d.py


    def forward(self, x, use_checkpoint=False, H=None, W=None, use_seg=False):
        """Forward function.

        Args:
            x: Input feature, tensor of size (B, C, D, H, W).
            use_checkpoint: Whether to use gradient checkpointing inside each
                block to trade compute for memory.
            H, W: Spatial sizes, only forwarded to `forward_seg`.
            use_seg: If True, dispatch to the segmentation forward path.
        """
        # Segmentation inputs take a separate path.
        if use_seg:
            return self.forward_seg(x, H, W)
        # Clamp window/shift sizes to the input resolution, then build the
        # attention mask for SW-MSA.
        B, C, D, H, W = x.shape
        window_size, shift_size = get_window_size(
            (D, H, W), self.window_size, self.shift_size
        )
        x = rearrange(x, "b c d h w -> b d h w c")
        # Round each dimension up to a multiple of the window size,
        # e.g. H = 30 with window 7 gives Hp = 35.
        Dp = int(np.ceil(D / window_size[0])) * window_size[0]
        Hp = int(np.ceil(H / window_size[1])) * window_size[1]
        Wp = int(np.ceil(W / window_size[2])) * window_size[2]
        # The mask is computed once on the padded (Dp, Hp, Wp) volume and
        # shared by all blocks in this stage.
        attn_mask = compute_mask(Dp, Hp, Wp, window_size, shift_size, x.device)
        # Run the stage's transformer blocks with the shared mask.
        for blk in self.blocks:
            x = blk(x, attn_mask, use_checkpoint=use_checkpoint)
        x = x.view(B, D, H, W, -1)

        # Optional patch merging at the end of the stage, then restore
        # the channels-first (B, C, D, H, W) layout.
        if self.downsample is not None:
            x = self.downsample(x)
        x = rearrange(x, "b d h w c -> b c d h w")
        return x
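
The helpers `get_window_size`, `compute_mask`, `rearrange` (einops), and `np`
(NumPy) are presumably imported at module level in models/swin_transformer_3d.py
and are not shown here. As a reference, below is a minimal sketch of what
`get_window_size` typically does in Video Swin-style code, assuming the standard
behavior of shrinking the window to fit small inputs and disabling the shift
along those dimensions:

def get_window_size(x_size, window_size, shift_size=None):
    # Sketch, not the verified source: shrink the window along any input
    # dimension it exceeds, and zero the cyclic shift there, since a
    # full-size window leaves nothing to shift.
    use_window_size = list(window_size)
    use_shift_size = list(shift_size) if shift_size is not None else None
    for i in range(len(x_size)):
        if x_size[i] <= window_size[i]:
            use_window_size[i] = x_size[i]
            if use_shift_size is not None:
                use_shift_size[i] = 0
    if use_shift_size is None:
        return tuple(use_window_size)
    return tuple(use_window_size), tuple(use_shift_size)

# Example: an 8-frame clip at 56x56 with window (8, 7, 7) and shift (4, 3, 3)
# keeps the window but disables the temporal shift, since D == window_size[0].
print(get_window_size((8, 56, 56), (8, 7, 7), (4, 3, 3)))
# -> ((8, 7, 7), (0, 3, 3))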