# get_ae_config() -- in scripts/convert_dcae_to_diffusers.py [0:0]


# `scaling_factor` values for the "-in" / "-mix" checkpoints, keyed by checkpoint name.
_IN_MIX_SCALING_FACTORS = {
    "dc-ae-f32c32-in-1.0": 0.3189,
    "dc-ae-f32c32-mix-1.0": 0.4552,
    "dc-ae-f64c128-in-1.0": 0.2889,
    "dc-ae-f64c128-mix-1.0": 0.4538,
    "dc-ae-f128c512-in-1.0": 0.4883,
    "dc-ae-f128c512-mix-1.0": 0.3620,
}


def _build_in_mix_config(latent_channels: int, num_vit_stages: int) -> dict:
    """Build the config shared by the "-in" / "-mix" checkpoint families.

    Every family starts with the same three ``ResBlock`` stages and differs only
    in ``latent_channels`` and in how many ``EfficientViTBlock`` stages follow.

    Args:
        latent_channels: Value stored under ``"latent_channels"``.
        num_vit_stages: Number of ``EfficientViTBlock`` stages after the three
            ``ResBlock`` stages (3, 4 or 5 for the known families).

    Returns:
        A new config dict without the ``"scaling_factor"`` entry.
    """
    num_res_stages = 3
    num_stages = num_res_stages + num_vit_stages
    # Widest known family has 8 stages; narrower ones use a prefix of this list.
    out_channels = [128, 256, 512, 512, 1024, 1024, 2048, 2048][:num_stages]
    block_types = ["ResBlock"] * num_res_stages + ["EfficientViTBlock"] * num_vit_stages

    # Encoder/decoder entries get their own list objects so that mutating one
    # entry of the returned config never leaks into another.
    return {
        "latent_channels": latent_channels,
        "encoder_block_types": list(block_types),
        "decoder_block_types": list(block_types),
        "encoder_block_out_channels": list(out_channels),
        "decoder_block_out_channels": list(out_channels),
        "encoder_layers_per_block": [0, 4, 8] + [2] * num_vit_stages,
        "decoder_layers_per_block": [0, 5, 10] + [2] * num_vit_stages,
        "encoder_qkv_multiscales": ((),) * num_stages,
        "decoder_qkv_multiscales": ((),) * num_stages,
        "decoder_norm_types": ["batch_norm"] * num_res_stages + ["rms_norm"] * num_vit_stages,
        "decoder_act_fns": ["relu"] * num_res_stages + ["silu"] * num_vit_stages,
    }


def get_ae_config(name: str) -> dict:
    """Return the autoencoder config dict for the checkpoint called ``name``.

    For every checkpoint except ``"dc-ae-f32c32-sana-1.0"`` this also has a side
    effect: the module-level ``AE_KEYS_RENAME_DICT`` is updated in place with the
    family-specific rename table (``AE_F32C32_KEYS``, ``AE_F64C128_KEYS`` or
    ``AE_F128C512_KEYS``).

    Args:
        name: Checkpoint name, e.g. ``"dc-ae-f32c32-in-1.0"``.

    Returns:
        A freshly built config dict, including its ``"scaling_factor"``.

    Raises:
        ValueError: If ``name`` is not one of the supported checkpoint names.
    """
    if name == "dc-ae-f32c32-sana-1.0":
        # The sana variant has its own layout and needs no extra key renames.
        return {
            "latent_channels": 32,
            "encoder_block_types": (
                "ResBlock",
                "ResBlock",
                "ResBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
            ),
            "decoder_block_types": (
                "ResBlock",
                "ResBlock",
                "ResBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
            ),
            "encoder_block_out_channels": (128, 256, 512, 512, 1024, 1024),
            "decoder_block_out_channels": (128, 256, 512, 512, 1024, 1024),
            "encoder_qkv_multiscales": ((), (), (), (5,), (5,), (5,)),
            "decoder_qkv_multiscales": ((), (), (), (5,), (5,), (5,)),
            "encoder_layers_per_block": (2, 2, 2, 3, 3, 3),
            "decoder_layers_per_block": [3, 3, 3, 3, 3, 3],
            "downsample_block_type": "conv",
            "upsample_block_type": "interpolate",
            "decoder_norm_types": "rms_norm",
            "decoder_act_fns": "silu",
            "scaling_factor": 0.41407,
        }

    if name in ("dc-ae-f32c32-in-1.0", "dc-ae-f32c32-mix-1.0"):
        AE_KEYS_RENAME_DICT.update(AE_F32C32_KEYS)
        config = _build_in_mix_config(latent_channels=32, num_vit_stages=3)
    elif name in ("dc-ae-f64c128-in-1.0", "dc-ae-f64c128-mix-1.0"):
        AE_KEYS_RENAME_DICT.update(AE_F64C128_KEYS)
        config = _build_in_mix_config(latent_channels=128, num_vit_stages=4)
    elif name in ("dc-ae-f128c512-in-1.0", "dc-ae-f128c512-mix-1.0"):
        AE_KEYS_RENAME_DICT.update(AE_F128C512_KEYS)
        config = _build_in_mix_config(latent_channels=512, num_vit_stages=5)
    else:
        # Include the offending value so a typo in the name is easy to spot.
        raise ValueError(f"Invalid config name provided: {name!r}.")

    # Added last so "scaling_factor" stays the final key of the dict.
    config["scaling_factor"] = _IN_MIX_SCALING_FACTORS[name]
    return config