in janus/janusflow/models/uvit.py [0:0]
def __init__(self, decoder_cfg, gpt_cfg, layer_id=None):
    """Build the output normalization layer and the tensor-parallel head.

    Args:
        decoder_cfg: Config object. Only ``decoder_cfg.in_channels`` is read
            here; it is the output width of ``self.head``.
        gpt_cfg: Config combined with the defaults below via
            ``AttrDict(...) + gpt_cfg``. It must supply ``n_embed``.
            NOTE(review): presumably entries in ``gpt_cfg`` override the
            defaults; confirm against ``AttrDict.__add__``.
        layer_id: Optional identifier, stored as ``self.layer_id``.

    Raises:
        AssertionError: if the resolved ``norm_type`` is neither
            ``"layernorm"`` nor ``"rmsnorm"``.
    """
    super().__init__()
    self.layer_id = layer_id
    # Defaults for every option read below; gpt_cfg is combined on top.
    cfg = (
        AttrDict(
            norm_type="layernorm",
            is_exp_norm=False,
            sequence_parallel=False,
            use_userbuffer=False,
            norm_eps=1e-5,
            norm_bias=True,
            gradient_accumulation_fusion=True,
            use_fp32_head_weight=False,
        )
        + gpt_cfg
    )
    group = PG.tensor_parallel_group()
    assert cfg.norm_type in [
        "layernorm",
        "rmsnorm",
    ], f"Norm type:{cfg.norm_type} not supported"
    if cfg.norm_type == "rmsnorm":
        self.norm = DropoutAddRMSNorm(
            cfg.n_embed,
            prenorm=False,
            eps=cfg.norm_eps,
            is_exp_norm=cfg.is_exp_norm,
            sequence_parallel=cfg.sequence_parallel,
        )
    else:
        # Only the LayerNorm variant is given the norm_bias option.
        self.norm = DropoutAddLayerNorm(
            cfg.n_embed,
            prenorm=False,
            eps=cfg.norm_eps,
            is_exp_norm=cfg.is_exp_norm,
            sequence_parallel=cfg.sequence_parallel,
            bias=cfg.norm_bias,
        )
    multiple_of = 256
    if decoder_cfg.in_channels % multiple_of != 0:
        # The (Chinese) message recommends a multiple of 256 for matmul
        # performance. NOTE(review): it names "vocab_size", but the value
        # checked is decoder_cfg.in_channels.
        warnings.warn(
            f"建议把 vocab_size 设置为 {multiple_of} 的倍数, 否则会影响矩阵乘法的性能"
        )
    default_dtype = torch.get_default_dtype()
    dtype = default_dtype
    if cfg.use_fp32_head_weight:
        dtype = torch.float32
        # The message states that fp32 head weights are incompatible with
        # the original bf16 head weights.
        print(
            "使用 fp32 head weight!!!! 与原来的 bf16 head weight 不兼容\n",
            end="",
            flush=True,
        )
    # The head is built under a temporarily switched process-wide default
    # dtype. Presumably ColumnParallelLinear allocates its parameters with
    # that default; confirm.
    torch.set_default_dtype(dtype)
    try:
        self.head = ColumnParallelLinear(
            cfg.n_embed,
            decoder_cfg.in_channels,
            bias=True,
            group=group,
            sequence_parallel=cfg.sequence_parallel,
            use_userbuffer=cfg.use_userbuffer,
            gradient_accumulation_fusion=cfg.gradient_accumulation_fusion,
            use_fp32_output=False,
        )
    finally:
        # Always restore the global default, even if construction raises.
        # Otherwise a failed build would leak fp32 as the default dtype into
        # every tensor created afterwards in this process.
        torch.set_default_dtype(default_dtype)
    self.use_fp32_head_weight = cfg.use_fp32_head_weight