in src/diffusers/loaders/lora_conversion_utils.py [0:0]
def _convert_hunyuan_video_lora_to_diffusers(original_state_dict):
    """Convert a HunyuanVideo LoRA state dict from the original checkpoint layout to the diffusers layout.

    A leading ``transformer.`` or ``diffusion_model.`` prefix on input keys is stripped first. Keys are then
    renamed, and fused projections are split into the separate projections diffusers uses. Every resulting key
    is prefixed with ``transformer.``.

    In a LoRA pair the update is ``lora_B @ lora_A``: ``lora_B`` indexes the layer's output features, and
    ``lora_A`` reads the layer's input. Splitting or reordering a layer's *outputs* therefore only slices or
    permutes ``lora_B`` (and biases). ``lora_A`` is reused unchanged by every projection derived from a fused
    layer.

    Args:
        original_state_dict (`dict`): Mapping of parameter names to tensors. It is consumed: every entry is popped
            from it, so it is empty when this function returns.

    Returns:
        `dict`: A new mapping with diffusers-style keys, all prefixed with ``transformer.``.
    """
    # Move (rather than copy) the entries so the caller's dict does not keep a second reference to every tensor.
    converted_state_dict = {k: original_state_dict.pop(k) for k in list(original_state_dict.keys())}

    def remap_fused_qkv_(key, state_dict, fused_name, target_names, rename=lambda k: k):
        """Replace one fused q/k/v entry with three entries named after `target_names`.

        `lora_A` tensors are replicated for each projection. Every other tensor (e.g. `lora_B` weight or bias) is
        chunked into three parts along dim 0, the output dimension. `rename` is applied to each new key.
        """
        weight = state_dict.pop(key)
        if "lora_A" in key:
            to_q, to_k, to_v = weight, weight, weight
        else:
            to_q, to_k, to_v = weight.chunk(3, dim=0)
        for target_name, part in zip(target_names, (to_q, to_k, to_v)):
            state_dict[rename(key.replace(fused_name, target_name))] = part

    def remap_norm_scale_shift_(key, state_dict):
        """Rename `final_layer.adaLN_modulation.1` to `norm_out.linear`, reordering outputs from (shift, scale)
        to (scale, shift).

        The reorder is a permutation of output rows, so it applies to `lora_B` (and to non-LoRA tensors).
        `lora_A` is left untouched: permuting its rows would shuffle the rank dimension and corrupt the delta.
        """
        weight = state_dict.pop(key)
        if "lora_A" not in key:
            shift, scale = weight.chunk(2, dim=0)
            weight = torch.cat([scale, shift], dim=0)
        state_dict[key.replace("final_layer.adaLN_modulation.1", "norm_out.linear")] = weight

    def remap_txt_in_(key, state_dict):
        """Rename a `txt_in` (text token refiner) entry, splitting its fused `self_attn_qkv` projections."""

        def rename_key(key):
            # Order matters: "txt_in" is renamed before the "t_embedder" / "mlp" substrings are rewritten.
            new_key = key.replace("individual_token_refiner.blocks", "token_refiner.refiner_blocks")
            new_key = new_key.replace("adaLN_modulation.1", "norm_out.linear")
            new_key = new_key.replace("txt_in", "context_embedder")
            new_key = new_key.replace("t_embedder.mlp.0", "time_text_embed.timestep_embedder.linear_1")
            new_key = new_key.replace("t_embedder.mlp.2", "time_text_embed.timestep_embedder.linear_2")
            new_key = new_key.replace("c_embedder", "time_text_embed.text_embedder")
            new_key = new_key.replace("mlp", "ff")
            return new_key

        if "self_attn_qkv" in key:
            remap_fused_qkv_(
                key, state_dict, "self_attn_qkv", ("attn.to_q", "attn.to_k", "attn.to_v"), rename=rename_key
            )
        else:
            state_dict[rename_key(key)] = state_dict.pop(key)

    def remap_img_attn_qkv_(key, state_dict):
        """Split the dual-stream image `img_attn_qkv` projection into `attn.to_{q,k,v}`."""
        remap_fused_qkv_(key, state_dict, "img_attn_qkv", ("attn.to_q", "attn.to_k", "attn.to_v"))

    def remap_txt_attn_qkv_(key, state_dict):
        """Split the dual-stream text `txt_attn_qkv` projection into `attn.add_{q,k,v}_proj`."""
        remap_fused_qkv_(key, state_dict, "txt_attn_qkv", ("attn.add_q_proj", "attn.add_k_proj", "attn.add_v_proj"))

    def remap_single_transformer_blocks_(key, state_dict):
        """Rename a `single_blocks` entry, splitting the fused `linear1` into q/k/v and the MLP input projection."""
        # NOTE(review): hard-coded transformer width; assumes the default HunyuanVideo config.
        hidden_size = 3072
        block_key = key.replace("single_blocks", "single_transformer_blocks")

        for lora_name in ("lora_A", "lora_B"):
            for param_name in ("weight", "bias"):
                fused_suffix = f"linear1.{lora_name}.{param_name}"
                if fused_suffix not in key:
                    continue
                tensor = state_dict.pop(key)
                prefix = block_key.removesuffix(f".{fused_suffix}")
                if lora_name == "lora_A":
                    # lora_A reads the shared block input, so every derived projection reuses it as-is.
                    parts = (tensor, tensor, tensor, tensor)
                else:
                    # linear1's outputs are laid out as [q | k | v | mlp]; mlp takes whatever follows q, k, v.
                    split_size = (hidden_size, hidden_size, hidden_size, tensor.size(0) - 3 * hidden_size)
                    parts = torch.split(tensor, split_size, dim=0)
                for target_name, part in zip(("attn.to_q", "attn.to_k", "attn.to_v", "proj_mlp"), parts):
                    state_dict[f"{prefix}.{target_name}.{lora_name}.{param_name}"] = part
                return

        new_key = block_key.replace("linear2", "proj_out")
        new_key = new_key.replace("q_norm", "attn.norm_q")
        new_key = new_key.replace("k_norm", "attn.norm_k")
        state_dict[new_key] = state_dict.pop(key)

    # Plain substring renames, applied in insertion order to every key.
    TRANSFORMER_KEYS_RENAME_DICT = {
        "img_in": "x_embedder",
        "time_in.mlp.0": "time_text_embed.timestep_embedder.linear_1",
        "time_in.mlp.2": "time_text_embed.timestep_embedder.linear_2",
        "guidance_in.mlp.0": "time_text_embed.guidance_embedder.linear_1",
        "guidance_in.mlp.2": "time_text_embed.guidance_embedder.linear_2",
        "vector_in.in_layer": "time_text_embed.text_embedder.linear_1",
        "vector_in.out_layer": "time_text_embed.text_embedder.linear_2",
        "double_blocks": "transformer_blocks",
        "img_attn_q_norm": "attn.norm_q",
        "img_attn_k_norm": "attn.norm_k",
        "img_attn_proj": "attn.to_out.0",
        "txt_attn_q_norm": "attn.norm_added_q",
        "txt_attn_k_norm": "attn.norm_added_k",
        "txt_attn_proj": "attn.to_add_out",
        "img_mod.linear": "norm1.linear",
        "img_norm1": "norm1.norm",
        "img_norm2": "norm2",
        "img_mlp": "ff",
        "txt_mod.linear": "norm1_context.linear",
        "txt_norm1": "norm1.norm",
        "txt_norm2": "norm2_context",
        "txt_mlp": "ff_context",
        "self_attn_proj": "attn.to_out.0",
        "modulation.linear": "norm.linear",
        "pre_norm": "norm.norm",
        "final_layer.norm_final": "norm_out.norm",
        "final_layer.linear": "proj_out",
        "fc1": "net.0.proj",
        "fc2": "net.2",
        "input_embedder": "proj_in",
    }

    # Keys containing one of these substrings need tensor surgery (splitting or reordering), not just a rename.
    # Each handler pops the key it is given and inserts the converted entries.
    TRANSFORMER_SPECIAL_KEYS_REMAP = {
        "txt_in": remap_txt_in_,
        "img_attn_qkv": remap_img_attn_qkv_,
        "txt_attn_qkv": remap_txt_attn_qkv_,
        "single_blocks": remap_single_transformer_blocks_,
        "final_layer.adaLN_modulation.1": remap_norm_scale_shift_,
    }

    # Some LoRAs keep the original key names but add a "transformer." prefix so they look diffusers-compatible.
    # Other LoRAs use a "diffusion_model." prefix. Strip either prefix so all inputs start from the same format.
    for key in list(converted_state_dict.keys()):
        for prefix in ("transformer.", "diffusion_model."):
            if key.startswith(prefix):
                converted_state_dict[key[len(prefix) :]] = converted_state_dict.pop(key)
                break

    # Rename the state dict keys
    for key in list(converted_state_dict.keys()):
        new_key = key
        for old_name, new_name in TRANSFORMER_KEYS_RENAME_DICT.items():
            new_key = new_key.replace(old_name, new_name)
        converted_state_dict[new_key] = converted_state_dict.pop(key)

    # Split / reorder the special keys
    for key in list(converted_state_dict.keys()):
        for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
            if special_key in key:
                handler_fn_inplace(key, converted_state_dict)
                # The handler has already popped `key`; a second matching handler would fail to find it.
                break

    # Add back the "transformer." prefix
    for key in list(converted_state_dict.keys()):
        converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)

    return converted_state_dict