in models/swin_transformer_3d.py [0:0]
def forward(self, x, mask_matrix, use_checkpoint=False):
    """Run both residual stages of the block and return the result.

    Args:
        x: Input feature, tensor size (B, D, H, W, C).
        mask_matrix: Attention mask for cyclic shift; forwarded unchanged
            to ``self.forward_part1``.
        use_checkpoint: When true, ``self.forward_part1`` and
            ``self.forward_part2`` are invoked through
            ``checkpoint.checkpoint`` instead of being called directly.

    Returns:
        ``y + self.forward_part2(y)`` where
        ``y = x + self.drop_path(self.forward_part1(x, mask_matrix))``.
    """

    def _invoke(stage, *stage_args):
        # Single switch point between a direct call and a checkpointed call.
        if use_checkpoint:
            return checkpoint.checkpoint(stage, *stage_args)
        return stage(*stage_args)

    residual = x

    # Stage 1: part-1 output goes through drop_path before the skip add.
    part1_out = _invoke(self.forward_part1, x, mask_matrix)
    x = residual + self.drop_path(part1_out)

    # Stage 2: part-2 output is added directly (no drop_path at this level).
    return x + _invoke(self.forward_part2, x)