in src/diffusers/loaders/single_file_utils.py
def infer_diffusers_model_type(checkpoint):
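    """Infer which `diffusers` model family a single-file checkpoint belongs to.

    ``checkpoint`` is a flat state dict. Detection probes for architecture-specific
    parameter names (``CHECKPOINT_KEY_NAMES``) and, where several variants share a
    key, disambiguates them by tensor shape. For example, inpainting UNets are
    spotted by a 9-channel ``conv_in`` (4 noise latents + 4 masked-image latents
    + 1 mask channel).
    """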
    if (
        CHECKPOINT_KEY_NAMES["inpainting"] in checkpoint
        and checkpoint[CHECKPOINT_KEY_NAMES["inpainting"]].shape[1] == 9
    ):
        if CHECKPOINT_KEY_NAMES["v2"] in checkpoint and checkpoint[CHECKPOINT_KEY_NAMES["v2"]].shape[-1] == 1024:
            model_type = "inpainting_v2"
        elif CHECKPOINT_KEY_NAMES["xl_base"] in checkpoint:
            model_type = "xl_inpaint"
        else:
            model_type = "inpainting"
    elif CHECKPOINT_KEY_NAMES["v2"] in checkpoint and checkpoint[CHECKPOINT_KEY_NAMES["v2"]].shape[-1] == 1024:
        model_type = "v2"
    elif CHECKPOINT_KEY_NAMES["playground-v2-5"] in checkpoint:
        model_type = "playground-v2-5"
    elif CHECKPOINT_KEY_NAMES["xl_base"] in checkpoint:
        model_type = "xl_base"
    elif CHECKPOINT_KEY_NAMES["xl_refiner"] in checkpoint:
        model_type = "xl_refiner"
    elif CHECKPOINT_KEY_NAMES["upscale"] in checkpoint:
        model_type = "upscale"
    elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["controlnet"]):
        if CHECKPOINT_KEY_NAMES["controlnet_xl"] in checkpoint:
            if CHECKPOINT_KEY_NAMES["controlnet_xl_large"] in checkpoint:
                model_type = "controlnet_xl_large"
            elif CHECKPOINT_KEY_NAMES["controlnet_xl_mid"] in checkpoint:
                model_type = "controlnet_xl_mid"
            else:
                model_type = "controlnet_xl_small"
        else:
            model_type = "controlnet"
    elif (
        CHECKPOINT_KEY_NAMES["stable_cascade_stage_c"] in checkpoint
        and checkpoint[CHECKPOINT_KEY_NAMES["stable_cascade_stage_c"]].shape[0] == 1536
    ):
        model_type = "stable_cascade_stage_c_lite"
    elif (
        CHECKPOINT_KEY_NAMES["stable_cascade_stage_c"] in checkpoint
        and checkpoint[CHECKPOINT_KEY_NAMES["stable_cascade_stage_c"]].shape[0] == 2048
    ):
        model_type = "stable_cascade_stage_c"
    elif (
        CHECKPOINT_KEY_NAMES["stable_cascade_stage_b"] in checkpoint
        and checkpoint[CHECKPOINT_KEY_NAMES["stable_cascade_stage_b"]].shape[-1] == 576
    ):
        model_type = "stable_cascade_stage_b_lite"
    elif (
        CHECKPOINT_KEY_NAMES["stable_cascade_stage_b"] in checkpoint
        and checkpoint[CHECKPOINT_KEY_NAMES["stable_cascade_stage_b"]].shape[-1] == 640
    ):
        model_type = "stable_cascade_stage_b"
    elif any(key in checkpoint and checkpoint[key].shape[-1] == 9216 for key in CHECKPOINT_KEY_NAMES["sd3"]):
        if "model.diffusion_model.pos_embed" in checkpoint:
            key = "model.diffusion_model.pos_embed"
        else:
            key = "pos_embed"
        if checkpoint[key].shape[1] == 36864:
            model_type = "sd3"
        elif checkpoint[key].shape[1] == 147456:
            model_type = "sd35_medium"
    elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["sd35_large"]):
        model_type = "sd35_large"
    elif CHECKPOINT_KEY_NAMES["animatediff"] in checkpoint:
        if CHECKPOINT_KEY_NAMES["animatediff_scribble"] in checkpoint:
            model_type = "animatediff_scribble"
        elif CHECKPOINT_KEY_NAMES["animatediff_rgb"] in checkpoint:
            model_type = "animatediff_rgb"
        elif CHECKPOINT_KEY_NAMES["animatediff_v2"] in checkpoint:
            model_type = "animatediff_v2"
        elif (
            CHECKPOINT_KEY_NAMES["animatediff_sdxl_beta"] in checkpoint
            and checkpoint[CHECKPOINT_KEY_NAMES["animatediff_sdxl_beta"]].shape[-1] == 320
        ):
            model_type = "animatediff_sdxl_beta"
        elif checkpoint[CHECKPOINT_KEY_NAMES["animatediff"]].shape[1] == 24:
            model_type = "animatediff_v1"
        else:
            model_type = "animatediff_v3"
    elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["flux"]):
        if any(
            g in checkpoint for g in ["guidance_in.in_layer.bias", "model.diffusion_model.guidance_in.in_layer.bias"]
        ):
            if "model.diffusion_model.img_in.weight" in checkpoint:
                key = "model.diffusion_model.img_in.weight"
            else:
                key = "img_in.weight"
            if checkpoint[key].shape[1] == 384:
                model_type = "flux-fill"
            elif checkpoint[key].shape[1] == 128:
                model_type = "flux-depth"
            else:
                model_type = "flux-dev"
        else:
            model_type = "flux-schnell"
    elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["ltx-video"]):
        has_vae = "vae.encoder.conv_in.conv.bias" in checkpoint
        if any(key.endswith("transformer_blocks.47.scale_shift_table") for key in checkpoint):
            model_type = "ltx-video-0.9.7"
        elif has_vae and checkpoint["vae.encoder.conv_out.conv.weight"].shape[1] == 2048:
            model_type = "ltx-video-0.9.5"
        elif "vae.decoder.last_time_embedder.timestep_embedder.linear_1.weight" in checkpoint:
            model_type = "ltx-video-0.9.1"
        else:
            model_type = "ltx-video"
    elif CHECKPOINT_KEY_NAMES["autoencoder-dc"] in checkpoint:
        encoder_key = "encoder.project_in.conv.conv.bias"
        decoder_key = "decoder.project_in.main.conv.weight"
        if CHECKPOINT_KEY_NAMES["autoencoder-dc-sana"] in checkpoint:
            model_type = "autoencoder-dc-f32c32-sana"
        elif checkpoint[encoder_key].shape[-1] == 64 and checkpoint[decoder_key].shape[1] == 32:
            model_type = "autoencoder-dc-f32c32"
        elif checkpoint[encoder_key].shape[-1] == 64 and checkpoint[decoder_key].shape[1] == 128:
            model_type = "autoencoder-dc-f64c128"
        else:
            model_type = "autoencoder-dc-f128c512"
    elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["mochi-1-preview"]):
        model_type = "mochi-1-preview"
    elif CHECKPOINT_KEY_NAMES["hunyuan-video"] in checkpoint:
        model_type = "hunyuan-video"
    elif all(key in checkpoint for key in CHECKPOINT_KEY_NAMES["auraflow"]):
        model_type = "auraflow"
    elif (
        CHECKPOINT_KEY_NAMES["instruct-pix2pix"] in checkpoint
        and checkpoint[CHECKPOINT_KEY_NAMES["instruct-pix2pix"]].shape[1] == 8
    ):
        model_type = "instruct-pix2pix"
    elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["lumina2"]):
        model_type = "lumina2"
    elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["sana"]):
        model_type = "sana"
    elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["wan"]):
        if "model.diffusion_model.patch_embedding.weight" in checkpoint:
            target_key = "model.diffusion_model.patch_embedding.weight"
        else:
            target_key = "patch_embedding.weight"
        if checkpoint[target_key].shape[0] == 1536:
            model_type = "wan-t2v-1.3B"
        elif checkpoint[target_key].shape[0] == 5120 and checkpoint[target_key].shape[1] == 16:
            model_type = "wan-t2v-14B"
        else:
            model_type = "wan-i2v-14B"
    elif CHECKPOINT_KEY_NAMES["wan_vae"] in checkpoint:
        # All Wan models use the same VAE, so we can use the same default model repo to fetch the config.
        model_type = "wan-t2v-14B"
    elif CHECKPOINT_KEY_NAMES["hidream"] in checkpoint:
        model_type = "hidream"
    else:
        model_type = "v1"
    return model_type
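
# A minimal usage sketch (not part of the library source): load a single-file
# checkpoint into a flat state dict and infer its model family. The file path is
# a placeholder, and the `safetensors` package is assumed to be installed.
if __name__ == "__main__":
    from safetensors.torch import load_file

    checkpoint = load_file("path/to/model.safetensors")  # placeholder path
    print(infer_diffusers_model_type(checkpoint))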