in models/swin_transformer_3d.py [0:0]
def get_patch_embedding(self, x):
    """Embed the input into patch tokens, handling an optional depth channel.

    Args:
        x: 5-dimensional input laid out as B x C x T x H x W. When C == 4 the
            input is treated as RGB-D: channels ``[:3]`` are RGB and channels
            ``[3:]`` are depth. Any other channel count is passed straight to
            ``self.patch_embed``.

    Returns:
        The output of the patch-embedding module(s) selected by
        ``self.depth_mode``. For "summed_rgb_d_tokens" this is the sum of the
        RGB and depth embeddings.

    Raises:
        AssertionError: if ``x`` is not 5-dimensional.
        NotImplementedError: if ``x`` has 4 channels and ``self.depth_mode`` is
            neither "summed_rgb_d_tokens" nor "rgbd".
    """
    assert x.ndim == 5, (
        "expected a 5-D input (B x C x T x H x W), got %d-D" % x.ndim
    )
    has_depth = x.shape[1] == 4
    if not has_depth:
        return self.patch_embed(x)

    if self.depth_mode == "summed_rgb_d_tokens":
        # Embed depth and RGB with separate modules (depth first, as before),
        # then fuse the two token sets by element-wise summation.
        x_d = self.depth_patch_embed(x[:, 3:, ...])
        x_rgb = self.patch_embed(x[:, :3, ...])
        return x_rgb + x_d

    if self.depth_mode == "rgbd":
        # All 4 channels go through a single embedding; the flag selects
        # whether depth-aware inputs use their own set of parameters.
        if self.depth_patch_embed_separate_params:
            return self.depth_patch_embed(x)
        return self.patch_embed(x)

    # Surface the offending mode in both the log and the traceback, instead of
    # only an info-level log line followed by a message-less exception.
    logging.error("Depth mode %s not supported", self.depth_mode)
    raise NotImplementedError(
        "Depth mode %s not supported" % (self.depth_mode,)
    )