def get_patch_embedding()

in models/swin_transformer_3d.py [0:0]


    def get_patch_embedding(self, x):
        """Convert an input clip into patch-embedding tokens.

        ``x`` is expected as B x C x T x H x W. When it carries exactly four
        channels, channels 0-2 are treated as RGB and channel 3 as depth, and
        the input is handled according to ``self.depth_mode``:

        * ``"summed_rgb_d_tokens"``: the RGB channels go through
          ``self.patch_embed`` and the depth channel through
          ``self.depth_patch_embed``; the two outputs are summed element-wise.
        * ``"rgbd"``: all four channels go through a single embedding,
          ``self.depth_patch_embed`` if
          ``self.depth_patch_embed_separate_params`` is truthy, otherwise
          ``self.patch_embed``.

        Inputs with any other channel count are passed to ``self.patch_embed``
        unchanged.

        Args:
            x: 5-D input clip laid out as B x C x T x H x W.

        Returns:
            The output of the selected embedding module(s). Its shape is
            determined by those modules, which are defined outside this method.

        Raises:
            AssertionError: if ``x`` is not 5-D (only when assertions are
                enabled).
            NotImplementedError: if ``x`` has a depth channel and
                ``self.depth_mode`` is not a supported mode.
        """
        assert x.ndim == 5, (
            "expected a 5-D input (B x C x T x H x W), got %d dims" % x.ndim
        )

        # Only a 4-channel input is treated as RGB + depth; everything else
        # goes straight through the standard patch embedding.
        if x.shape[1] != 4:
            return self.patch_embed(x)

        if self.depth_mode == "summed_rgb_d_tokens":
            # Embed depth and RGB separately, then fuse by summing the tokens.
            x_d = self.depth_patch_embed(x[:, 3:, ...])
            x_rgb = self.patch_embed(x[:, :3, ...])
            return x_rgb + x_d

        if self.depth_mode == "rgbd":
            # Embed all four channels jointly, optionally with dedicated params.
            if self.depth_patch_embed_separate_params:
                return self.depth_patch_embed(x)
            return self.patch_embed(x)

        logging.info("Depth mode %s not supported", self.depth_mode)
        # Put the reason in the exception itself so callers see why it failed.
        raise NotImplementedError(
            "Depth mode %s not supported" % (self.depth_mode,)
        )