in muse/modeling_transformer_v2.py [0:0]
def __init__(self, **kwargs):
    super().__init__()

    config = config_from_legacy_kwargs(**kwargs)
    self.register_to_config(**dataclasses.asdict(config))
    self.register_to_config(mask_token_id=self.config.vocab_size - 1)

    # TODO: Allow enabling fused norm using a function (like we do for xformers attention)
    if self.config.use_fused_residual_norm and dropout_add_layer_norm is None:
        warnings.warn(
            "Cannot use fused layer norm. Please install flash_attn. Falling back to unfused layer norm",
            UserWarning,
        )
        self.register_to_config(use_fused_residual_norm=False)
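    # `dropout_add_layer_norm` is flash_attn's fused residual-add + LayerNorm kernel.
    # The exact import is not shown in this excerpt; it is presumably guarded at the top
    # of the file along the lines of:
    #     try:
    #         from flash_attn.ops.layer_norm import dropout_add_layer_norm
    #     except ImportError:
    #         dropout_add_layer_norm = None
    # so it is None when flash_attn (with its fused layer-norm extension) is not
    # installed, and the config falls back to the unfused path above.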
    assert len(self.config.block_out_channels) == 1

    # Legacy: kept for compatibility with pipeline
    self.output_size = self.config.codebook_size

    self.encoder_proj = nn.Linear(
        self.config.encoder_hidden_size, self.config.hidden_size, bias=self.config.use_bias
    )
    self.encoder_proj_layer_norm = Norm(self.config.hidden_size, self.config)

    self.embed = ConvEmbed(self.config)

    self.cond_embed = nn.Sequential(
        nn.Linear(
            self.config.micro_cond_embed_dim + self.config.cond_embed_dim,
            self.config.hidden_size,
            bias=self.config.use_bias,
        ),
        nn.SiLU(),
        nn.Linear(self.config.hidden_size, self.config.hidden_size, bias=self.config.use_bias),
    )

    self.down_blocks = nn.ModuleList([DownsampleBlock(self.config.block_out_channels[0], self.config)])

    self.project_to_hidden_norm = Norm(self.config.block_out_channels[-1], self.config)
    self.project_to_hidden = nn.Linear(
        self.config.block_out_channels[-1], self.config.hidden_size, bias=self.config.use_bias
    )

    self.transformer_layers = nn.ModuleList(
        [TransformerLayer(self.config) for _ in range(self.config.num_hidden_layers)]
    )

    self.project_from_hidden_norm = Norm(self.config.hidden_size, self.config)
    self.project_from_hidden = nn.Linear(
        self.config.hidden_size, self.config.block_out_channels[-1], bias=self.config.use_bias
    )

    self.up_blocks = nn.ModuleList([UpsampleBlock(self.config.block_out_channels[0], self.config)])

    self.mlm_layer = ConvMlmLayer(self.config)

    self.gradient_checkpointing = False
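    # The modules above imply an hourglass layout. forward() is outside this excerpt, so
    # the order below is inferred from the definitions rather than quoted:
    #   token ids -> embed (ConvEmbed) -> down_blocks -> project_to_hidden(_norm) ->
    #   transformer_layers (conditioned via cond_embed and the projected encoder states) ->
    #   project_from_hidden(_norm) -> up_blocks -> mlm_layer -> logits over codebook_size.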
    # --- WEIGHT INIT ---
    self.apply(self._init_weights)  # General init

    nn.init.xavier_uniform_(self.embed.conv.weight, 0.02)  # inputs
    nn.init.normal_(self.embed.embeddings.weight, std=np.sqrt(1 / self.config.vocab_size))

    nn.init.constant_(self.mlm_layer.conv1.weight, 0)  # output
    # initialize the output 1x1 conv from the token embedding table
    # (first codebook_size rows, reshaped to (codebook_size, dim, 1, 1))
    self.mlm_layer.conv2.weight.data = self.embed.embeddings.weight.data[
        : self.config.codebook_size, :, None, None
    ].clone()

    # init AdaLNModulation.mapper layers to 0
    for m in self.modules():
        if isinstance(m, AdaLNModulation):
            nn.init.constant_(m.mapper.weight, 0)
            if hasattr(m, "bias") and m.bias is not None:
                nn.init.constant_(m.bias, 0)
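
# ---- Illustration only, not part of the file above ----
# Why the AdaLNModulation mappers are zero-initialized: with weight (and bias) at 0, the
# predicted scale/shift start at 0, so the conditioning is a no-op at the start of training
# (the adaLN-zero idea). The `x * (1 + scale) + shift` form below is an assumed application;
# this repo's AdaLNModulation may apply its mapper outputs differently.
import torch
import torch.nn as nn

hidden_size = 8
mapper = nn.Linear(hidden_size, 2 * hidden_size)
nn.init.constant_(mapper.weight, 0)
nn.init.constant_(mapper.bias, 0)

cond = torch.randn(2, hidden_size)   # conditioning embedding
scale, shift = mapper(cond).chunk(2, dim=-1)

x = torch.randn(2, hidden_size)      # hidden states entering a modulated block
assert torch.allclose(x * (1 + scale) + shift, x)  # identity at init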