# src/diffusers/loaders/lora_conversion_utils.py
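# NOTE: this excerpt assumes the module-level imports and helpers defined elsewhere in this
# file are available: `torch`, `re`, `logger`, `state_dict_all_zero`, `_custom_replace`,
# `_convert_text_encoder_lora_key`, and `swap_scale_shift`.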
def _convert_kohya_flux_lora_to_diffusers(state_dict):
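# Converts a kohya-ss / sd-scripts FLUX LoRA state dict into the diffusers (PEFT) layout.
# Representative key mapping (illustrative, not exhaustive):
#   "lora_unet_double_blocks_0_img_attn_proj.lora_down.weight"
#       -> "transformer.transformer_blocks.0.attn.to_out.0.lora_A.weight"
#   "lora_unet_double_blocks_0_img_attn_proj.lora_up.weight"
#       -> "transformer.transformer_blocks.0.attn.to_out.0.lora_B.weight"
# Per-module `alpha` values are not carried over as separate entries; where handled, their
# scaling is folded into the converted weights below.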
def _convert_to_ai_toolkit(sds_sd, ait_sd, sds_key, ait_key):
if sds_key + ".lora_down.weight" not in sds_sd:
return
down_weight = sds_sd.pop(sds_key + ".lora_down.weight")
# scale weight by alpha and dim
rank = down_weight.shape[0]
default_alpha = torch.tensor(rank, dtype=down_weight.dtype, device=down_weight.device, requires_grad=False)
alpha = sds_sd.pop(sds_key + ".alpha", default_alpha).item() # alpha is scalar
scale = alpha / rank # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
# split the scale between down and up while keeping scale_down * scale_up == scale; as written,
# the loop only rebalances when scale < 0.5 (e.g. scale 0.25 -> scale_down 0.5, scale_up 0.5),
# otherwise scale_down == scale and scale_up == 1.0
scale_down = scale
scale_up = 1.0
while scale_down * 2 < scale_up:
scale_down *= 2
scale_up /= 2
ait_sd[ait_key + ".lora_A.weight"] = down_weight * scale_down
ait_sd[ait_key + ".lora_B.weight"] = sds_sd.pop(sds_key + ".lora_up.weight") * scale_up
def _convert_to_ai_toolkit_cat(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
if sds_key + ".lora_down.weight" not in sds_sd:
return
down_weight = sds_sd.pop(sds_key + ".lora_down.weight")
up_weight = sds_sd.pop(sds_key + ".lora_up.weight")
sd_lora_rank = down_weight.shape[0]
# scale weight by alpha and dim
default_alpha = torch.tensor(
sd_lora_rank, dtype=down_weight.dtype, device=down_weight.device, requires_grad=False
)
alpha = sds_sd.pop(sds_key + ".alpha", default_alpha)
scale = alpha / sd_lora_rank
# calculate scale_down and scale_up
scale_down = scale
scale_up = 1.0
while scale_down * 2 < scale_up:
scale_down *= 2
scale_up /= 2
down_weight = down_weight * scale_down
up_weight = up_weight * scale_up
# calculate dims if not provided
num_splits = len(ait_keys)
if dims is None:
dims = [up_weight.shape[0] // num_splits] * num_splits
else:
assert sum(dims) == up_weight.shape[0]
# check whether up_weight is block-diagonal ("sparse") across the splits, i.e. each split's
# rows only touch its own rank-sized block of columns
is_sparse = False
if sd_lora_rank % num_splits == 0:
ait_rank = sd_lora_rank // num_splits
is_sparse = True
i = 0
for j in range(len(dims)):
for k in range(len(dims)):
if j == k:
continue
is_sparse = is_sparse and torch.all(
up_weight[i : i + dims[j], k * ait_rank : (k + 1) * ait_rank] == 0
)
i += dims[j]
if is_sparse:
logger.info(f"weight is sparse: {sds_key}")
# make ai-toolkit weight
ait_down_keys = [k + ".lora_A.weight" for k in ait_keys]
ait_up_keys = [k + ".lora_B.weight" for k in ait_keys]
if not is_sparse:
# down_weight is copied to each split
ait_sd.update(dict.fromkeys(ait_down_keys, down_weight))
# up_weight is split to each split
ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))}) # noqa: C416
else:
# down_weight is chunked to each split
ait_sd.update({k: v for k, v in zip(ait_down_keys, torch.chunk(down_weight, num_splits, dim=0))}) # noqa: C416
# up_weight is sparse: only non-zero values are copied to each split
i = 0
for j in range(len(dims)):
ait_sd[ait_up_keys[j]] = up_weight[i : i + dims[j], j * ait_rank : (j + 1) * ait_rank].contiguous()
i += dims[j]
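# Illustrative example (shapes assume FLUX's hidden size of 3072): for a double-block
# "img_attn_qkv" module, up_weight has shape (3 * 3072, rank). In the non-sparse case the
# shared down_weight is copied to to_q/to_k/to_v and up_weight is split row-wise; in the
# block-diagonal ("sparse") case both are split, so each output module keeps rank // 3.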
def _convert_sd_scripts_to_ai_toolkit(sds_sd):
ait_sd = {}
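# FLUX.1 has 19 double-stream (joint image/text) transformer blocks; each one contributes
# image-stream and text-stream attention, MLP, and modulation LoRA modules, converted below.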
for i in range(19):
_convert_to_ai_toolkit(
sds_sd,
ait_sd,
f"lora_unet_double_blocks_{i}_img_attn_proj",
f"transformer.transformer_blocks.{i}.attn.to_out.0",
)
_convert_to_ai_toolkit_cat(
sds_sd,
ait_sd,
f"lora_unet_double_blocks_{i}_img_attn_qkv",
[
f"transformer.transformer_blocks.{i}.attn.to_q",
f"transformer.transformer_blocks.{i}.attn.to_k",
f"transformer.transformer_blocks.{i}.attn.to_v",
],
)
_convert_to_ai_toolkit(
sds_sd,
ait_sd,
f"lora_unet_double_blocks_{i}_img_mlp_0",
f"transformer.transformer_blocks.{i}.ff.net.0.proj",
)
_convert_to_ai_toolkit(
sds_sd,
ait_sd,
f"lora_unet_double_blocks_{i}_img_mlp_2",
f"transformer.transformer_blocks.{i}.ff.net.2",
)
_convert_to_ai_toolkit(
sds_sd,
ait_sd,
f"lora_unet_double_blocks_{i}_img_mod_lin",
f"transformer.transformer_blocks.{i}.norm1.linear",
)
_convert_to_ai_toolkit(
sds_sd,
ait_sd,
f"lora_unet_double_blocks_{i}_txt_attn_proj",
f"transformer.transformer_blocks.{i}.attn.to_add_out",
)
_convert_to_ai_toolkit_cat(
sds_sd,
ait_sd,
f"lora_unet_double_blocks_{i}_txt_attn_qkv",
[
f"transformer.transformer_blocks.{i}.attn.add_q_proj",
f"transformer.transformer_blocks.{i}.attn.add_k_proj",
f"transformer.transformer_blocks.{i}.attn.add_v_proj",
],
)
_convert_to_ai_toolkit(
sds_sd,
ait_sd,
f"lora_unet_double_blocks_{i}_txt_mlp_0",
f"transformer.transformer_blocks.{i}.ff_context.net.0.proj",
)
_convert_to_ai_toolkit(
sds_sd,
ait_sd,
f"lora_unet_double_blocks_{i}_txt_mlp_2",
f"transformer.transformer_blocks.{i}.ff_context.net.2",
)
_convert_to_ai_toolkit(
sds_sd,
ait_sd,
f"lora_unet_double_blocks_{i}_txt_mod_lin",
f"transformer.transformer_blocks.{i}.norm1_context.linear",
)
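# FLUX.1 has 38 single-stream blocks. Their fused "linear1" packs Q, K, V (3072 each) and the
# MLP input projection (12288 = 4 * 3072) into one matrix, hence the explicit `dims` below.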
for i in range(38):
_convert_to_ai_toolkit_cat(
sds_sd,
ait_sd,
f"lora_unet_single_blocks_{i}_linear1",
[
f"transformer.single_transformer_blocks.{i}.attn.to_q",
f"transformer.single_transformer_blocks.{i}.attn.to_k",
f"transformer.single_transformer_blocks.{i}.attn.to_v",
f"transformer.single_transformer_blocks.{i}.proj_mlp",
],
dims=[3072, 3072, 3072, 12288],
)
_convert_to_ai_toolkit(
sds_sd,
ait_sd,
f"lora_unet_single_blocks_{i}_linear2",
f"transformer.single_transformer_blocks.{i}.proj_out",
)
_convert_to_ai_toolkit(
sds_sd,
ait_sd,
f"lora_unet_single_blocks_{i}_modulation_lin",
f"transformer.single_transformer_blocks.{i}.norm.linear",
)
# TODO: alphas.
def assign_remaining_weights(assignments, source):
for lora_key in ["lora_A", "lora_B"]:
orig_lora_key = "lora_down" if lora_key == "lora_A" else "lora_up"
for target_fmt, source_fmt, transform in assignments:
target_key = target_fmt.format(lora_key=lora_key)
source_key = source_fmt.format(orig_lora_key=orig_lora_key)
value = source.pop(source_key)
if transform:
value = transform(value)
ait_sd[target_key] = value
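# Example expansion: ("x_embedder.{lora_key}.weight", "lora_unet_img_in.{orig_lora_key}.weight", None)
# moves "lora_unet_img_in.lora_down.weight" -> ait_sd["x_embedder.lora_A.weight"] (and likewise
# lora_up -> lora_B). Unlike the block conversions above, no alpha scaling is applied here
# (see the TODO above).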
if any("guidance_in" in k for k in sds_sd):
assign_remaining_weights(
[
(
"time_text_embed.guidance_embedder.linear_1.{lora_key}.weight",
"lora_unet_guidance_in_in_layer.{orig_lora_key}.weight",
None,
),
(
"time_text_embed.guidance_embedder.linear_2.{lora_key}.weight",
"lora_unet_guidance_in_out_layer.{orig_lora_key}.weight",
None,
),
],
sds_sd,
)
if any("img_in" in k for k in sds_sd):
assign_remaining_weights(
[
("x_embedder.{lora_key}.weight", "lora_unet_img_in.{orig_lora_key}.weight", None),
],
sds_sd,
)
if any("txt_in" in k for k in sds_sd):
assign_remaining_weights(
[
("context_embedder.{lora_key}.weight", "lora_unet_txt_in.{orig_lora_key}.weight", None),
],
sds_sd,
)
if any("time_in" in k for k in sds_sd):
assign_remaining_weights(
[
(
"time_text_embed.timestep_embedder.linear_1.{lora_key}.weight",
"lora_unet_time_in_in_layer.{orig_lora_key}.weight",
None,
),
(
"time_text_embed.timestep_embedder.linear_2.{lora_key}.weight",
"lora_unet_time_in_out_layer.{orig_lora_key}.weight",
None,
),
],
sds_sd,
)
if any("vector_in" in k for k in sds_sd):
assign_remaining_weights(
[
(
"time_text_embed.text_embedder.linear_1.{lora_key}.weight",
"lora_unet_vector_in_in_layer.{orig_lora_key}.weight",
None,
),
(
"time_text_embed.text_embedder.linear_2.{lora_key}.weight",
"lora_unet_vector_in_out_layer.{orig_lora_key}.weight",
None,
),
],
sds_sd,
)
if any("final_layer" in k for k in sds_sd):
# Note the extra transform for "final_layer": swap_scale_shift reorders the chunked shift/scale
# halves of the adaLN modulation to match diffusers' norm_out.linear layout.
assign_remaining_weights(
[
(
"norm_out.linear.{lora_key}.weight",
"lora_unet_final_layer_adaLN_modulation_1.{orig_lora_key}.weight",
swap_scale_shift,
),
("proj_out.{lora_key}.weight", "lora_unet_final_layer_linear.{orig_lora_key}.weight", None),
],
sds_sd,
)
remaining_keys = list(sds_sd.keys())
te_state_dict = {}
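# Remaining "lora_te"/"lora_te1" keys belong to the text encoder (CLIP-L for FLUX) in kohya
# naming; they are converted to diffusers text-encoder keys, with alpha folded in the same
# way as above.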
if remaining_keys:
if not all(k.startswith(("lora_te", "lora_te1")) for k in remaining_keys):
raise ValueError(f"Incompatible keys detected: \n\n {', '.join(remaining_keys)}")
for key in remaining_keys:
if not key.endswith("lora_down.weight"):
continue
lora_name = key.split(".")[0]
lora_name_up = f"{lora_name}.lora_up.weight"
lora_name_alpha = f"{lora_name}.alpha"
diffusers_name = _convert_text_encoder_lora_key(key, lora_name)
if lora_name.startswith(("lora_te_", "lora_te1_")):
down_weight = sds_sd.pop(key)
sd_lora_rank = down_weight.shape[0]
te_state_dict[diffusers_name] = down_weight
te_state_dict[diffusers_name.replace(".down.", ".up.")] = sds_sd.pop(lora_name_up)
if lora_name_alpha in sds_sd:
alpha = sds_sd.pop(lora_name_alpha).item()
scale = alpha / sd_lora_rank
scale_down = scale
scale_up = 1.0
while scale_down * 2 < scale_up:
scale_down *= 2
scale_up /= 2
te_state_dict[diffusers_name] *= scale_down
te_state_dict[diffusers_name.replace(".down.", ".up.")] *= scale_up
if len(sds_sd) > 0:
logger.warning(f"Unsupported keys for ai-toolkit: {sds_sd.keys()}")
if te_state_dict:
te_state_dict = {f"text_encoder.{module_name}": params for module_name, params in te_state_dict.items()}
new_state_dict = {**ait_sd, **te_state_dict}
return new_state_dict
def _convert_mixture_state_dict_to_diffusers(state_dict):
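# "Mixture" checkpoints use diffusers-style module paths flattened into kohya-style keys, e.g.
# "lora_transformer_single_transformer_blocks_0_attn_to_q" with lora_down/lora_up/alpha
# suffixes, optionally alongside "lora_te1_" CLIP text-encoder keys.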
new_state_dict = {}
def _convert(original_key, diffusers_key, state_dict, new_state_dict):
down_key = f"{original_key}.lora_down.weight"
down_weight = state_dict.pop(down_key)
lora_rank = down_weight.shape[0]
up_weight_key = f"{original_key}.lora_up.weight"
up_weight = state_dict.pop(up_weight_key)
alpha_key = f"{original_key}.alpha"
alpha = state_dict.pop(alpha_key)
# scale weight by alpha and dim
scale = alpha / lora_rank
# calculate scale_down and scale_up
scale_down = scale
scale_up = 1.0
while scale_down * 2 < scale_up:
scale_down *= 2
scale_up /= 2
down_weight = down_weight * scale_down
up_weight = up_weight * scale_up
diffusers_down_key = f"{diffusers_key}.lora_A.weight"
new_state_dict[diffusers_down_key] = down_weight
new_state_dict[diffusers_down_key.replace(".lora_A.", ".lora_B.")] = up_weight
all_unique_keys = {
k.replace(".lora_down.weight", "").replace(".lora_up.weight", "").replace(".alpha", "")
for k in state_dict
if not k.startswith("lora_unet_")
}
assert all(k.startswith(("lora_transformer_", "lora_te1_")) for k in all_unique_keys), f"{all_unique_keys=}"
has_te_keys = False
for k in all_unique_keys:
if k.startswith("lora_transformer_single_transformer_blocks_"):
i = int(k.split("lora_transformer_single_transformer_blocks_")[-1].split("_")[0])
diffusers_key = f"single_transformer_blocks.{i}"
elif k.startswith("lora_transformer_transformer_blocks_"):
i = int(k.split("lora_transformer_transformer_blocks_")[-1].split("_")[0])
diffusers_key = f"transformer_blocks.{i}"
elif k.startswith("lora_te1_"):
has_te_keys = True
continue
elif k.startswith("lora_transformer_context_embedder"):
diffusers_key = "context_embedder"
elif k.startswith("lora_transformer_norm_out_linear"):
diffusers_key = "norm_out.linear"
elif k.startswith("lora_transformer_proj_out"):
diffusers_key = "proj_out"
elif k.startswith("lora_transformer_x_embedder"):
diffusers_key = "x_embedder"
elif k.startswith("lora_transformer_time_text_embed_guidance_embedder_linear_"):
i = int(k.split("lora_transformer_time_text_embed_guidance_embedder_linear_")[-1])
diffusers_key = f"time_text_embed.guidance_embedder.linear_{i}"
elif k.startswith("lora_transformer_time_text_embed_text_embedder_linear_"):
i = int(k.split("lora_transformer_time_text_embed_text_embedder_linear_")[-1])
diffusers_key = f"time_text_embed.text_embedder.linear_{i}"
elif k.startswith("lora_transformer_time_text_embed_timestep_embedder_linear_"):
i = int(k.split("lora_transformer_time_text_embed_timestep_embedder_linear_")[-1])
diffusers_key = f"time_text_embed.timestep_embedder.linear_{i}"
else:
raise NotImplementedError(f"Handling for key ({k}) is not implemented.")
if "attn_" in k:
if "_to_out_0" in k:
diffusers_key += ".attn.to_out.0"
elif "_to_add_out" in k:
diffusers_key += ".attn.to_add_out"
elif any(qkv in k for qkv in ["to_q", "to_k", "to_v"]):
remaining = k.split("attn_")[-1]
diffusers_key += f".attn.{remaining}"
elif any(add_qkv in k for add_qkv in ["add_q_proj", "add_k_proj", "add_v_proj"]):
remaining = k.split("attn_")[-1]
diffusers_key += f".attn.{remaining}"
_convert(k, diffusers_key, state_dict, new_state_dict)
if has_te_keys:
layer_pattern = re.compile(r"lora_te1_text_model_encoder_layers_(\d+)")
attn_mapping = {
"q_proj": ".self_attn.q_proj",
"k_proj": ".self_attn.k_proj",
"v_proj": ".self_attn.v_proj",
"out_proj": ".self_attn.out_proj",
}
mlp_mapping = {"fc1": ".mlp.fc1", "fc2": ".mlp.fc2"}
for k in all_unique_keys:
if not k.startswith("lora_te1_"):
continue
match = layer_pattern.search(k)
if not match:
continue
i = int(match.group(1))
diffusers_key = f"text_model.encoder.layers.{i}"
if "attn" in k:
for key_fragment, suffix in attn_mapping.items():
if key_fragment in k:
diffusers_key += suffix
break
elif "mlp" in k:
for key_fragment, suffix in mlp_mapping.items():
if key_fragment in k:
diffusers_key += suffix
break
_convert(k, diffusers_key, state_dict, new_state_dict)
remaining_all_unet = False
if state_dict:
remaining_all_unet = all(k.startswith("lora_unet_") for k in state_dict)
if remaining_all_unet:
keys = list(state_dict.keys())
for k in keys:
state_dict.pop(k)
if len(state_dict) > 0:
raise ValueError(
f"Expected an empty state dict at this point but its has these keys which couldn't be parsed: {list(state_dict.keys())}."
)
transformer_state_dict = {
f"transformer.{k}": v for k, v in new_state_dict.items() if not k.startswith("text_model.")
}
te_state_dict = {f"text_encoder.{k}": v for k, v in new_state_dict.items() if k.startswith("text_model.")}
return {**transformer_state_dict, **te_state_dict}
# This is weird.
# https://huggingface.co/sayakpaul/different-lora-from-civitai/tree/main?show_file_info=sharp_detailed_foot.safetensors
# has both `peft` and non-peft state dict.
has_peft_state_dict = any(k.startswith("transformer.") for k in state_dict)
if has_peft_state_dict:
state_dict = {k: v for k, v in state_dict.items() if k.startswith("transformer.")}
return state_dict
# Another weird one.
has_mixture = any(
k.startswith("lora_transformer_") and ("lora_down" in k or "lora_up" in k or "alpha" in k) for k in state_dict
)
# ComfyUI-style checkpoints: normalize their key prefixes to the sd-scripts naming handled above.
if not has_mixture:
state_dict = {k.replace("diffusion_model.", "lora_unet_"): v for k, v in state_dict.items()}
state_dict = {k.replace("text_encoders.clip_l.transformer.", "lora_te_"): v for k, v in state_dict.items()}
has_position_embedding = any("position_embedding" in k for k in state_dict)
if has_position_embedding:
zero_status_pe = state_dict_all_zero(state_dict, "position_embedding")
if zero_status_pe:
logger.info(
"The `position_embedding` LoRA params are all zeros which make them ineffective. "
"So, we will purge them out of the current state dict to make loading possible."
)
else:
logger.info(
"The state_dict has position_embedding LoRA params and we currently do not support them. "
"Open an issue if you need this supported - https://github.com/huggingface/diffusers/issues/new."
)
state_dict = {k: v for k, v in state_dict.items() if "position_embedding" not in k}
has_t5xxl = any(k.startswith("text_encoders.t5xxl.transformer.") for k in state_dict)
if has_t5xxl:
zero_status_t5 = state_dict_all_zero(state_dict, "text_encoders.t5xxl")
if zero_status_t5:
logger.info(
"The `t5xxl` LoRA params are all zeros which make them ineffective. "
"So, we will purge them out of the current state dict to make loading possible."
)
else:
logger.info(
"T5-xxl keys found in the state dict, which are currently unsupported. We will filter them out."
"Open an issue if this is a problem - https://github.com/huggingface/diffusers/issues/new."
)
state_dict = {k: v for k, v in state_dict.items() if not k.startswith("text_encoders.t5xxl.transformer.")}
has_diffb = any("diff_b" in k and k.startswith(("lora_unet_", "lora_te_")) for k in state_dict)
if has_diffb:
zero_status_diff_b = state_dict_all_zero(state_dict, ".diff_b")
if zero_status_diff_b:
logger.info(
"The `diff_b` LoRA params are all zeros which make them ineffective. "
"So, we will purge them out of the current state dict to make loading possible."
)
else:
logger.info(
"`diff_b` keys found in the state dict which are currently unsupported. "
"So, we will filter out those keys. Open an issue if this is a problem - "
"https://github.com/huggingface/diffusers/issues/new."
)
state_dict = {k: v for k, v in state_dict.items() if ".diff_b" not in k}
has_norm_diff = any(".norm" in k and ".diff" in k for k in state_dict)
if has_norm_diff:
zero_status_diff = state_dict_all_zero(state_dict, ".diff")
if zero_status_diff:
logger.info(
"The `diff` LoRA params are all zeros which make them ineffective. "
"So, we will purge them out of the current state dict to make loading possible."
)
else:
logger.info(
"Normalization diff keys found in the state dict which are currently unsupported. "
"So, we will filter out those keys. Open an issue if this is a problem - "
"https://github.com/huggingface/diffusers/issues/new."
)
state_dict = {k: v for k, v in state_dict.items() if ".norm" not in k and ".diff" not in k}
limit_substrings = ["lora_down", "lora_up"]
if any("alpha" in k for k in state_dict):
limit_substrings.append("alpha")
state_dict = {
_custom_replace(k, limit_substrings): v
for k, v in state_dict.items()
if k.startswith(("lora_unet_", "lora_te_"))
}
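# `_custom_replace` is assumed to normalize the separators in each key up to the lora_down /
# lora_up (and alpha) suffixes, so ComfyUI-style dotted keys end up in the underscore-separated
# sd-scripts naming that `_convert_sd_scripts_to_ai_toolkit` expects.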
if any("text_projection" in k for k in state_dict):
logger.info(
"`text_projection` keys found in the `state_dict` which are unexpected. "
"So, we will filter out those keys. Open an issue if this is a problem - "
"https://github.com/huggingface/diffusers/issues/new."
)
state_dict = {k: v for k, v in state_dict.items() if "text_projection" not in k}
if has_mixture:
return _convert_mixture_state_dict_to_diffusers(state_dict)
return _convert_sd_scripts_to_ai_toolkit(state_dict)
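# Usage sketch (illustrative; in practice this converter is invoked internally by the Flux LoRA
# loading path once kohya-style keys are detected, rather than called directly):
#
#   from safetensors.torch import load_file
#
#   raw_sd = load_file("kohya_flux_lora.safetensors")  # hypothetical local checkpoint
#   peft_sd = _convert_kohya_flux_lora_to_diffusers(raw_sd)
#   # peft_sd keys now look like "transformer.transformer_blocks.0.attn.to_q.lora_A.weight",
#   # plus "text_encoder...." entries when the checkpoint also trains CLIP-L.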