def __init__()

in janus/janusflow/models/uvit.py [0:0]


    def __init__(self, decoder_cfg, gpt_cfg, layer_id=None):
        super().__init__()
        self.layer_id = layer_id
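        # Merge hard-coded defaults with gpt_cfg (the AttrDict merge presumably
        # lets fields in gpt_cfg override the defaults below).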
        cfg = (
            AttrDict(
                norm_type="layernorm",
                is_exp_norm=False,
                sequence_parallel=False,
                use_userbuffer=False,
                norm_eps=1e-5,
                norm_bias=True,
                gradient_accumulation_fusion=True,
                use_fp32_head_weight=False,
            )
            + gpt_cfg
        )
        group = PG.tensor_parallel_group()
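        # Final normalization layer: fused dropout-add RMSNorm or LayerNorm,
        # selected by cfg.norm_type.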
        assert cfg.norm_type in [
            "layernorm",
            "rmsnorm",
        ], f"Norm type:{cfg.norm_type} not supported"
        if cfg.norm_type == "rmsnorm":
            self.norm = DropoutAddRMSNorm(
                cfg.n_embed,
                prenorm=False,
                eps=cfg.norm_eps,
                is_exp_norm=cfg.is_exp_norm,
                sequence_parallel=cfg.sequence_parallel,
            )
        else:
            self.norm = DropoutAddLayerNorm(
                cfg.n_embed,
                prenorm=False,
                eps=cfg.norm_eps,
                is_exp_norm=cfg.is_exp_norm,
                sequence_parallel=cfg.sequence_parallel,
                bias=cfg.norm_bias,
            )

        multiple_of = 256
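        # Warn when decoder_cfg.in_channels is not a multiple of 256, since a
        # misaligned output dimension hurts matrix-multiplication performance.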
        if decoder_cfg.in_channels % multiple_of != 0:
            warnings.warn(
                f"建议把 vocab_size 设置为 {multiple_of} 的倍数, 否则会影响矩阵乘法的性能"
            )

        dtype = default_dtype = torch.get_default_dtype()
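        # Optionally create the head weight in fp32 by temporarily swapping the
        # default dtype; the original default is restored after construction.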
        if cfg.use_fp32_head_weight:
            dtype = torch.float32
            print(
                "使用 fp32 head weight!!!! 与原来的 bf16 head weight 不兼容\n",
                end="",
                flush=True,
            )
        torch.set_default_dtype(dtype)
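        # Head projection from the hidden size to the decoder's input channels,
        # sharded column-wise across the tensor-parallel group.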
        self.head = ColumnParallelLinear(
            cfg.n_embed,
            decoder_cfg.in_channels,
            bias=True,
            group=group,
            sequence_parallel=cfg.sequence_parallel,
            use_userbuffer=cfg.use_userbuffer,
            gradient_accumulation_fusion=cfg.gradient_accumulation_fusion,
            use_fp32_output=False,
        )
        torch.set_default_dtype(default_dtype)

        self.use_fp32_head_weight = cfg.use_fp32_head_weight
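
The construction above depends on repo-specific parallel building blocks (AttrDict, PG, DropoutAddRMSNorm / DropoutAddLayerNorm, ColumnParallelLinear). As a rough, self-contained sketch of just the fp32-head-weight trick, the snippet below swaps the default dtype around a plain torch.nn.Linear; build_head and the dimensions are illustrative placeholders, not part of the repo.

    import torch
    import torch.nn as nn

    def build_head(n_embed: int, out_channels: int, use_fp32_head_weight: bool) -> nn.Linear:
        # Hypothetical helper mirroring the dtype-swap pattern above.
        default_dtype = torch.get_default_dtype()
        dtype = torch.float32 if use_fp32_head_weight else default_dtype
        torch.set_default_dtype(dtype)          # parameters created below use this dtype
        head = nn.Linear(n_embed, out_channels, bias=True)
        torch.set_default_dtype(default_dtype)  # restore the global default
        return head

    torch.set_default_dtype(torch.bfloat16)
    head = build_head(n_embed=2048, out_channels=768, use_fp32_head_weight=True)
    print(head.weight.dtype)  # torch.float32, while other modules default to bfloat16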