def forward_features()

in slowfast/models/video_model_builder.py


    def forward_features(self, x):
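        # Video inputs arrive as a list of pathway tensors (SlowFast
        # convention); this ViT backbone uses a single pathway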
        if self.video_input:
            x = x[0]
        B = x.shape[0]

        # Tokenize input
        if self.cfg.VIT.PATCH_SIZE_TEMP > 1:
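            # 3D tokenization: the patch embedding also spans
            # PATCH_SIZE_TEMP frames (tubelet embedding)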
            x = self.patch_embed_3d(x)
        else:
            # 2D tokenization
            if self.video_input:
                x = x.permute(0, 2, 1, 3, 4)
                (B, T, C, H, W) = x.shape
                x = x.reshape(B*T, C, H, W)

            x = self.patch_embed(x)

            if self.video_input:
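                # patch_embed ran per frame on a (B*T, C, H, W) batch, so
                # B2 == B*T and T2 is the number of spatial patches per
                # frame; fold frames back into one token sequence per clip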
                (B2, T2, D2) = x.shape
                x = x.reshape(B, T*T2, D2)

        # Prepend CLS token
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # Interpolate positional embeddings
        if self.cfg.DATA.TRAIN_CROP_SIZE != 224:
            pos_embed = self.pos_embed
            N = pos_embed.shape[1] - 1
            npatch = int((x.size(1) - 1) / self.temporal_resolution)
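            # N: spatial patches covered by the pretrained pos_embed grid;
            # npatch: spatial patches per frame at the current crop size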
            class_emb = pos_embed[:, 0]
            pos_embed = pos_embed[:, 1:]
            dim = x.shape[-1]
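            # Reshape the flat table into its sqrt(N) x sqrt(N) grid and
            # bicubic-resize it to the sqrt(npatch) x sqrt(npatch) grid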
            pos_embed = torch.nn.functional.interpolate(
                pos_embed.reshape(
                    1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(
                    0, 3, 1, 2),
                scale_factor=math.sqrt(npatch / N),
                mode='bicubic',
            )
            pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
            new_pos_embed = torch.cat((class_emb.unsqueeze(0), pos_embed), dim=1)
        else:
            new_pos_embed = self.pos_embed
            npatch = self.patch_embed.num_patches

        # Add positional embeddings to input
        if self.video_input:
            if self.cfg.VIT.POS_EMBED == "separate":
                cls_embed = self.pos_embed[:, 0, :].unsqueeze(1)
                tile_pos_embed = new_pos_embed[:, 1:, :].repeat(
                    1, self.temporal_resolution, 1)
                tile_temporal_embed = self.temp_embed.repeat_interleave(
                    npatch, 1)
                total_pos_embed = tile_pos_embed + tile_temporal_embed
                total_pos_embed = torch.cat([cls_embed, total_pos_embed], dim=1)
                x = x + total_pos_embed
            elif self.cfg.VIT.POS_EMBED == "joint":
                x = x + self.st_embed
        else:
            # image input
            x = x + new_pos_embed

        # Apply positional dropout
        x = self.pos_drop(x)

        # Encode with the transformer blocks
        for i, blk in enumerate(self.blocks):
            x = blk(
                x,
                seq_len=npatch,
                num_frames=self.temporal_resolution,
                approx=self.cfg.VIT.APPROX_ATTN_TYPE,
                num_landmarks=self.cfg.VIT.APPROX_ATTN_DIM
            )

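        # Final norm; keep only the CLS token as the clip-level feature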
        x = self.norm(x)[:, 0]
        x = self.pre_logits(x)
        return x
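
For reference, a minimal usage sketch (the model construction and the
(B, C, T, H, W) input layout are assumptions inferred from the permute
above, not part of this file):

    # Hypothetical driver: assumes `model` is a built instance of this
    # backbone with video_input=True
    import torch

    B, C, T, H, W = 2, 3, 8, 224, 224
    clip = torch.randn(B, C, T, H, W)
    # Inputs are passed as a list of pathways; this backbone reads x[0]
    feats = model.forward_features([clip])
    # feats: (B, embed_dim) CLS-token features for the downstream head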