def __init__()

in slowfast/models/video_model_builder.py

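This constructor reads all of its hyper-parameters from cfg and builds the model in stages: 2D and 3D patch embeddings, CLS token and positional embeddings, the transformer blocks, and the classifier head(s). The excerpt starts mid-file; a plausible reconstruction of the module-level imports it relies on (the exact import sources are assumptions):

    from collections import OrderedDict
    from functools import partial

    import torch
    import torch.nn as nn
    from torch.nn.init import trunc_normal_

    from slowfast.models import vit_helper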

    def __init__(self, cfg):
        super().__init__()
        self.img_size = cfg.DATA.TRAIN_CROP_SIZE
        self.patch_size = cfg.VIT.PATCH_SIZE
        self.in_chans = cfg.VIT.CHANNELS
        if cfg.TRAIN.DATASET == "Epickitchens":
            self.num_classes = [97, 300]  # EPIC-Kitchens: 97 verb, 300 noun classes
        else:
            self.num_classes = cfg.MODEL.NUM_CLASSES
        self.embed_dim = cfg.VIT.EMBED_DIM
        self.depth = cfg.VIT.DEPTH
        self.num_heads = cfg.VIT.NUM_HEADS
        self.mlp_ratio = cfg.VIT.MLP_RATIO
        self.qkv_bias = cfg.VIT.QKV_BIAS
        self.drop_rate = cfg.VIT.DROP
        self.drop_path_rate = cfg.VIT.DROP_PATH
        self.head_dropout = cfg.VIT.HEAD_DROPOUT
        self.video_input = cfg.VIT.VIDEO_INPUT
        self.temporal_resolution = cfg.VIT.TEMPORAL_RESOLUTION
        self.use_mlp = cfg.VIT.USE_MLP
        self.num_features = self.embed_dim
        norm_layer = partial(nn.LayerNorm, eps=1e-6)
        self.attn_drop_rate = cfg.VIT.ATTN_DROPOUT
        self.head_act = cfg.VIT.HEAD_ACT
        self.cfg = cfg

        # Patch Embedding
        self.patch_embed = vit_helper.PatchEmbed(
            img_size=224, 
            patch_size=self.patch_size, 
            in_chans=self.in_chans, 
            embed_dim=self.embed_dim
        )
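        # NOTE: img_size is hard-coded to 224 here, while the 3D embedding
        # below uses cfg.DATA.TRAIN_CROP_SIZE. With the common patch_size=16
        # this gives (224 // 16) ** 2 = 196 spatial patches per frame.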

        # 3D Patch Embedding
        self.patch_embed_3d = vit_helper.PatchEmbed3D(
            img_size=self.img_size, 
            temporal_resolution=self.temporal_resolution, 
            patch_size=self.patch_size,
            in_chans=self.in_chans, 
            embed_dim=self.embed_dim, 
            z_block_size=self.cfg.VIT.PATCH_SIZE_TEMP
        )
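        # Zero-init the 3D projection weights; presumably pretrained 2D
        # patch-embedding weights are inflated into this layer when a
        # checkpoint is loaded elsewhere in the codebase.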
        self.patch_embed_3d.proj.weight.data = torch.zeros_like(
            self.patch_embed_3d.proj.weight.data)
        
        # Number of patches
        if self.video_input:
            num_patches = self.patch_embed.num_patches * self.temporal_resolution
        else:
            num_patches = self.patch_embed.num_patches
        self.num_patches = num_patches
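        # e.g. 196 spatial patches * temporal_resolution 8 = 1568 tokens
        # (illustrative numbers for a 224 crop with a 16x16 patch size).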

        # CLS token
        self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
        trunc_normal_(self.cls_token, std=.02)
        
        # Positional embedding
        self.pos_embed = nn.Parameter(
            torch.zeros(1, self.patch_embed.num_patches + 1, self.embed_dim))
        self.pos_drop = nn.Dropout(p=cfg.VIT.POS_DROPOUT)
        trunc_normal_(self.pos_embed, std=.02)

        if self.cfg.VIT.POS_EMBED == "joint":
            self.st_embed = nn.Parameter(
                torch.zeros(1, num_patches + 1, self.embed_dim))
            trunc_normal_(self.st_embed, std=.02)
        elif self.cfg.VIT.POS_EMBED == "separate":
            self.temp_embed = nn.Parameter(
                torch.zeros(1, self.temporal_resolution, self.embed_dim))
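        # "joint" learns a single table over all space-time positions (+ CLS);
        # "separate" keeps the spatial pos_embed above and adds a per-frame
        # temp_embed, which (unlike st_embed) is left zero-initialized here.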

        # Layer Blocks
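        # Stochastic depth: per-block drop-path rates grow linearly from 0 to
        # cfg.VIT.DROP_PATH over the self.depth blocks.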
        dpr = [x.item() for x in torch.linspace(
            0, self.drop_path_rate, self.depth)]
        if self.cfg.VIT.ATTN_LAYER == "divided":
            self.blocks = nn.ModuleList([
                vit_helper.DividedSpaceTimeBlock(
                    attn_type=cfg.VIT.ATTN_LAYER, 
                    dim=self.embed_dim, 
                    num_heads=self.num_heads,
                    mlp_ratio=self.mlp_ratio, 
                    qkv_bias=self.qkv_bias, 
                    drop=self.drop_rate, 
                    attn_drop=self.attn_drop_rate, 
                    drop_path=dpr[i], 
                    norm_layer=norm_layer, 
                )
                for i in range(self.depth)
            ])
        else:
            self.blocks = nn.ModuleList([
                vit_helper.Block(
                    attn_type=cfg.VIT.ATTN_LAYER, 
                    dim=self.embed_dim, 
                    num_heads=self.num_heads,
                    mlp_ratio=self.mlp_ratio, 
                    qkv_bias=self.qkv_bias, 
                    drop=self.drop_rate, 
                    attn_drop=self.attn_drop_rate, 
                    drop_path=dpr[i], 
                    norm_layer=norm_layer,
                    use_original_code=self.cfg.VIT.USE_ORIGINAL_TRAJ_ATTN_CODE
                )
                for i in range(self.depth)
            ])
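        # "divided" selects TimeSformer-style divided space-time attention;
        # any other ATTN_LAYER value falls through to the generic Block, whose
        # exact attention variant is determined inside vit_helper.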
        self.norm = norm_layer(self.embed_dim)

        # MLP head
        if self.use_mlp:
            hidden_dim = self.embed_dim
            if self.head_act == 'tanh':
                print("Using TanH activation in MLP")
                act = nn.Tanh() 
            elif self.head_act == 'gelu':
                print("Using GELU activation in MLP")
                act = nn.GELU()
            else:
                print("Using ReLU activation in MLP")
                act = nn.ReLU()
            self.pre_logits = nn.Sequential(OrderedDict([
                ('fc', nn.Linear(self.embed_dim, hidden_dim)),
                ('act', act),
            ]))
        else:
            self.pre_logits = nn.Identity()
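        # Mirrors the ViT "pre-logits" head: Linear + activation when USE_MLP
        # is set (hidden_dim equals embed_dim here), identity otherwise.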
        
        # Classifier Head
        self.head_drop = nn.Dropout(p=self.head_dropout)
        if isinstance(self.num_classes, list) and len(self.num_classes) > 1:
            for i, num_classes in enumerate(self.num_classes):
                setattr(self, "head%d" % i, nn.Linear(self.embed_dim, num_classes))
        else:
            self.head = (nn.Linear(self.embed_dim, self.num_classes) 
                if self.num_classes > 0 else nn.Identity())
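        # e.g. with the EPIC-Kitchens branch above this creates self.head0
        # (97 verb classes) and self.head1 (300 noun classes).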

        # Initialize weights
        self.apply(self._init_weights)
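
For reference, a minimal sketch of a config object that satisfies every field this constructor reads. It is illustrative only: the real values come from the project's YAML configs and config defaults, SimpleNamespace stands in for the fvcore-style CfgNode the codebase presumably uses, and VideoModel is a hypothetical name for the class this __init__ belongs to:

    from types import SimpleNamespace as NS

    # Illustrative values only -- not the project's defaults.
    cfg = NS(
        DATA=NS(TRAIN_CROP_SIZE=224),
        TRAIN=NS(DATASET="kinetics"),
        MODEL=NS(NUM_CLASSES=400),
        VIT=NS(
            PATCH_SIZE=16, PATCH_SIZE_TEMP=2, CHANNELS=3, EMBED_DIM=768,
            DEPTH=12, NUM_HEADS=12, MLP_RATIO=4, QKV_BIAS=True,
            DROP=0.0, DROP_PATH=0.1, HEAD_DROPOUT=0.0, POS_DROPOUT=0.0,
            ATTN_DROPOUT=0.0, HEAD_ACT="tanh", VIDEO_INPUT=True,
            TEMPORAL_RESOLUTION=8, USE_MLP=True, POS_EMBED="separate",
            ATTN_LAYER="divided", USE_ORIGINAL_TRAJ_ATTN_CODE=True,
        ),
    )
    model = VideoModel(cfg)  # hypothetical class name for the class above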