in src/diffusers/loaders/single_file_utils.py [0:0]
def convert_chroma_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
    """Rename and split a Chroma transformer state dict into the diffusers key layout.

    Keys carrying a ``model.diffusion_model.`` prefix are first stripped of it.
    Fused ``qkv`` tensors of the double blocks are chunked into three equal
    parts along dim 0, and the fused ``linear1`` tensors of the single blocks
    are split into q / k / v / mlp parts along dim 0.

    Args:
        checkpoint (dict): Mapping from parameter name to tensor. It is
            modified in place: every converted entry is popped from it.
        **kwargs: Accepted for signature compatibility; not used here.

    Returns:
        dict: A new mapping from diffusers-style parameter names to tensors.

    Raises:
        KeyError: If a key expected for one of the detected blocks is missing.
    """
    converted_state_dict = {}

    # Snapshot the keys first: the dict is mutated while we walk over it.
    for key in list(checkpoint.keys()):
        if "model.diffusion_model." in key:
            checkpoint[key.replace("model.diffusion_model.", "")] = checkpoint.pop(key)

    def count_blocks(marker, index_position):
        # Highest block index found in the key names, plus one. `max` is used
        # instead of "last element of a set" because sets are unordered.
        # Yields 0 when no key contains `marker`.
        indices = {int(k.split(".")[index_position]) for k in checkpoint if marker in k}
        return max(indices, default=-1) + 1

    def move(new_key, old_key):
        # Transfer one tensor from the source checkpoint to the converted dict.
        converted_state_dict[new_key] = checkpoint.pop(old_key)

    num_layers = count_blocks("double_blocks.", 1)
    num_single_layers = count_blocks("single_blocks.", 1)
    num_guidance_layers = count_blocks("distilled_guidance_layer.layers.", 2)

    # Hard-coded model dimensions; the fused `linear1` tensors of the single
    # blocks must have exactly 3 * inner_dim + inner_dim * mlp_ratio rows.
    mlp_ratio = 4.0
    inner_dim = 3072
    mlp_hidden_dim = int(inner_dim * mlp_ratio)
    split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim)

    # guidance
    for proj in ("in_proj", "out_proj"):
        for param in ("bias", "weight"):
            name = f"distilled_guidance_layer.{proj}.{param}"
            move(name, name)

    for i in range(num_guidance_layers):
        layer = f"distilled_guidance_layer.layers.{i}."
        move(f"{layer}linear_1.bias", f"{layer}in_layer.bias")
        move(f"{layer}linear_1.weight", f"{layer}in_layer.weight")
        move(f"{layer}linear_2.bias", f"{layer}out_layer.bias")
        move(f"{layer}linear_2.weight", f"{layer}out_layer.weight")
        move(f"distilled_guidance_layer.norms.{i}.weight", f"distilled_guidance_layer.norms.{i}.scale")

    # context_embedder
    move("context_embedder.weight", "txt_in.weight")
    move("context_embedder.bias", "txt_in.bias")

    # x_embedder
    move("x_embedder.weight", "img_in.weight")
    move("x_embedder.bias", "img_in.bias")

    # double transformer blocks
    for i in range(num_layers):
        dst = f"transformer_blocks.{i}."
        src = f"double_blocks.{i}."

        # Q, K, V: fused tensors are chunked into three equal parts on dim 0.
        sample_q, sample_k, sample_v = torch.chunk(checkpoint.pop(f"{src}img_attn.qkv.weight"), 3, dim=0)
        context_q, context_k, context_v = torch.chunk(checkpoint.pop(f"{src}txt_attn.qkv.weight"), 3, dim=0)
        sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk(
            checkpoint.pop(f"{src}img_attn.qkv.bias"), 3, dim=0
        )
        context_q_bias, context_k_bias, context_v_bias = torch.chunk(
            checkpoint.pop(f"{src}txt_attn.qkv.bias"), 3, dim=0
        )

        # NOTE(review): torch.cat over a one-element list is kept from the
        # original; presumably it materialises each chunk as its own tensor
        # rather than storing a view of the fused one - confirm before removing.
        converted_state_dict[f"{dst}attn.to_q.weight"] = torch.cat([sample_q])
        converted_state_dict[f"{dst}attn.to_q.bias"] = torch.cat([sample_q_bias])
        converted_state_dict[f"{dst}attn.to_k.weight"] = torch.cat([sample_k])
        converted_state_dict[f"{dst}attn.to_k.bias"] = torch.cat([sample_k_bias])
        converted_state_dict[f"{dst}attn.to_v.weight"] = torch.cat([sample_v])
        converted_state_dict[f"{dst}attn.to_v.bias"] = torch.cat([sample_v_bias])
        converted_state_dict[f"{dst}attn.add_q_proj.weight"] = torch.cat([context_q])
        converted_state_dict[f"{dst}attn.add_q_proj.bias"] = torch.cat([context_q_bias])
        converted_state_dict[f"{dst}attn.add_k_proj.weight"] = torch.cat([context_k])
        converted_state_dict[f"{dst}attn.add_k_proj.bias"] = torch.cat([context_k_bias])
        converted_state_dict[f"{dst}attn.add_v_proj.weight"] = torch.cat([context_v])
        converted_state_dict[f"{dst}attn.add_v_proj.bias"] = torch.cat([context_v_bias])

        # qk_norm
        move(f"{dst}attn.norm_q.weight", f"{src}img_attn.norm.query_norm.scale")
        move(f"{dst}attn.norm_k.weight", f"{src}img_attn.norm.key_norm.scale")
        move(f"{dst}attn.norm_added_q.weight", f"{src}txt_attn.norm.query_norm.scale")
        move(f"{dst}attn.norm_added_k.weight", f"{src}txt_attn.norm.key_norm.scale")

        # ff img_mlp
        move(f"{dst}ff.net.0.proj.weight", f"{src}img_mlp.0.weight")
        move(f"{dst}ff.net.0.proj.bias", f"{src}img_mlp.0.bias")
        move(f"{dst}ff.net.2.weight", f"{src}img_mlp.2.weight")
        move(f"{dst}ff.net.2.bias", f"{src}img_mlp.2.bias")
        move(f"{dst}ff_context.net.0.proj.weight", f"{src}txt_mlp.0.weight")
        move(f"{dst}ff_context.net.0.proj.bias", f"{src}txt_mlp.0.bias")
        move(f"{dst}ff_context.net.2.weight", f"{src}txt_mlp.2.weight")
        move(f"{dst}ff_context.net.2.bias", f"{src}txt_mlp.2.bias")

        # output projections.
        move(f"{dst}attn.to_out.0.weight", f"{src}img_attn.proj.weight")
        move(f"{dst}attn.to_out.0.bias", f"{src}img_attn.proj.bias")
        move(f"{dst}attn.to_add_out.weight", f"{src}txt_attn.proj.weight")
        move(f"{dst}attn.to_add_out.bias", f"{src}txt_attn.proj.bias")

    # single transformer blocks
    for i in range(num_single_layers):
        dst = f"single_transformer_blocks.{i}."
        src = f"single_blocks.{i}."

        # Q, K, V, mlp: one fused projection split by the fixed sizes above.
        q, k, v, mlp = torch.split(checkpoint.pop(f"{src}linear1.weight"), split_size, dim=0)
        q_bias, k_bias, v_bias, mlp_bias = torch.split(checkpoint.pop(f"{src}linear1.bias"), split_size, dim=0)
        converted_state_dict[f"{dst}attn.to_q.weight"] = torch.cat([q])
        converted_state_dict[f"{dst}attn.to_q.bias"] = torch.cat([q_bias])
        converted_state_dict[f"{dst}attn.to_k.weight"] = torch.cat([k])
        converted_state_dict[f"{dst}attn.to_k.bias"] = torch.cat([k_bias])
        converted_state_dict[f"{dst}attn.to_v.weight"] = torch.cat([v])
        converted_state_dict[f"{dst}attn.to_v.bias"] = torch.cat([v_bias])
        converted_state_dict[f"{dst}proj_mlp.weight"] = torch.cat([mlp])
        converted_state_dict[f"{dst}proj_mlp.bias"] = torch.cat([mlp_bias])

        # qk norm
        move(f"{dst}attn.norm_q.weight", f"{src}norm.query_norm.scale")
        move(f"{dst}attn.norm_k.weight", f"{src}norm.key_norm.scale")

        # output projections.
        move(f"{dst}proj_out.weight", f"{src}linear2.weight")
        move(f"{dst}proj_out.bias", f"{src}linear2.bias")

    move("proj_out.weight", "final_layer.linear.weight")
    move("proj_out.bias", "final_layer.linear.bias")

    return converted_state_dict