def __init__()

in muse/modeling_transformer_v2.py


    def __init__(self, **kwargs):
        super().__init__()

        config = config_from_legacy_kwargs(**kwargs)
        self.register_to_config(**dataclasses.asdict(config))
        self.register_to_config(mask_token_id=self.config.vocab_size - 1)  # last vocab entry is reserved as the mask token

        # TODO: Allow enabling fused norm using a function (like we do for xformers attention)
        if self.config.use_fused_residual_norm and dropout_add_layer_norm is None:
            warnings.warn("Cannot use fused layer norm. Please install flash_attn. Falling back to unfused layer norm", UserWarning)
            self.register_to_config(use_fused_residual_norm=False)

        assert len(self.config.block_out_channels) == 1  # only a single down/up stage is supported

        # Legacy: kept for compatibility with pipeline
        self.output_size = self.config.codebook_size

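        # Project text encoder hidden states (encoder_hidden_size) to the transformer width (hidden_size).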
        self.encoder_proj = nn.Linear(
            self.config.encoder_hidden_size, self.config.hidden_size, bias=self.config.use_bias
        )
        self.encoder_proj_layer_norm = Norm(self.config.hidden_size, self.config)

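        # ConvEmbed: token embedding table plus a conv projection (its weights are initialized below).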
        self.embed = ConvEmbed(self.config)

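        # Two-layer MLP mapping the concatenated micro-conditioning and condition embeddings
        # (micro_cond_embed_dim + cond_embed_dim) to hidden_size.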
        self.cond_embed = nn.Sequential(
            nn.Linear(
                self.config.micro_cond_embed_dim + self.config.cond_embed_dim,
                self.config.hidden_size,
                bias=self.config.use_bias,
            ),
            nn.SiLU(),
            nn.Linear(self.config.hidden_size, self.config.hidden_size, bias=self.config.use_bias),
        )

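        # Backbone layout: downsample block -> project to hidden_size -> transformer layers
        # -> project back to conv channels -> upsample block -> conv MLM head.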
        self.down_blocks = nn.ModuleList([DownsampleBlock(self.config.block_out_channels[0], self.config)])

        self.project_to_hidden_norm = Norm(self.config.block_out_channels[-1], self.config)
        self.project_to_hidden = nn.Linear(
            self.config.block_out_channels[-1], self.config.hidden_size, bias=self.config.use_bias
        )

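        # Core transformer stack operating at hidden_size.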
        self.transformer_layers = nn.ModuleList(
            [TransformerLayer(self.config) for _ in range(self.config.num_hidden_layers)]
        )

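        # Project back from hidden_size to the conv channel width for the upsampling path.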
        self.project_from_hidden_norm = Norm(self.config.hidden_size, self.config)
        self.project_from_hidden = nn.Linear(
            self.config.hidden_size, self.config.block_out_channels[-1], bias=self.config.use_bias
        )

        self.up_blocks = nn.ModuleList([UpsampleBlock(self.config.block_out_channels[0], self.config)])

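        # Convolutional MLM head producing logits over the codebook; its conv2 weights are
        # initialized from the token embedding table below.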
        self.mlm_layer = ConvMlmLayer(self.config)

        self.gradient_checkpointing = False

        # --- WEIGHT INIT ---
        self.apply(self._init_weights)  # General init
        nn.init.xavier_uniform_(self.embed.conv.weight, 0.02)  # inputs
        nn.init.normal_(self.embed.embeddings.weight, std=np.sqrt(1 / self.config.vocab_size))
        nn.init.constant_(self.mlm_layer.conv1.weight, 0)  # output
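        # Copy (not tie) the first codebook_size rows of the token embedding table into the
        # MLM head's conv2, reshaped to 1x1 conv kernels.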
        self.mlm_layer.conv2.weight.data = self.embed.embeddings.weight.data[
            : self.config.codebook_size, :, None, None
        ].clone()

        # init AdaLNModulation.mapper layers to 0
        for m in self.modules():
            if isinstance(m, AdaLNModulation):
                nn.init.constant_(m.mapper.weight, 0)
                if hasattr(m.mapper, "bias") and m.mapper.bias is not None:
                    nn.init.constant_(m.mapper.bias, 0)
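
For context, a minimal construction sketch follows. The class name, the exact keyword names accepted by config_from_legacy_kwargs, and every numeric value are illustrative assumptions; only the config fields referenced in __init__ above are taken from the source.

    # Hypothetical usage sketch -- the class name and kwarg spellings are assumptions,
    # not confirmed by the code above.
    from muse.modeling_transformer_v2 import MaskGiTUViT_v2  # assumed class name

    model = MaskGiTUViT_v2(
        vocab_size=8256,                # illustrative; the last id becomes mask_token_id
        codebook_size=8192,             # illustrative
        hidden_size=1024,               # illustrative
        num_hidden_layers=22,           # illustrative
        encoder_hidden_size=768,        # illustrative text-encoder width
        block_out_channels=(768,),      # __init__ asserts exactly one entry
        cond_embed_dim=768,             # illustrative
        micro_cond_embed_dim=1280,      # illustrative
        use_bias=False,
        use_fused_residual_norm=False,  # set True only if flash_attn is installed
    )
    assert model.config.mask_token_id == model.config.vocab_size - 1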