def config_from_legacy_kwargs()

in muse/modeling_transformer_v2.py [0:0]


def config_from_legacy_kwargs(**kwargs):
    """Build a ``MaskGiTUViT_v2Config`` from legacy-style keyword arguments.

    ``block_num_heads`` may be given either as an ``int`` or as a one-element
    tuple/list; in the latter case it is unwrapped to its single value.
    Keywords whose names are not fields of ``MaskGiTUViT_v2Config`` are
    ignored, and fields that are not supplied keep the dataclass defaults.

    Args:
        **kwargs: Configuration values keyed by config field name.

    Returns:
        A ``MaskGiTUViT_v2Config`` whose ``block_out_channels`` has been
        converted to a ``list``.

    Raises:
        ValueError: If ``block_num_heads`` is a tuple/list whose length is
            not exactly 1.
        TypeError: If ``block_num_heads`` is neither an ``int`` nor a
            tuple/list.
    """
    if "block_num_heads" in kwargs:
        num_heads = kwargs["block_num_heads"]
        # Explicit raises rather than ``assert``: asserts are stripped when
        # Python runs with ``-O``, which would let bad values through silently.
        if isinstance(num_heads, (tuple, list)):
            if len(num_heads) != 1:
                raise ValueError(
                    "block_num_heads given as a tuple/list must have exactly "
                    f"one element, got {len(num_heads)}: {num_heads!r}"
                )
            kwargs["block_num_heads"] = num_heads[0]
        elif not isinstance(num_heads, int):
            raise TypeError(
                "block_num_heads must be an int or a one-element tuple/list, "
                f"got {type(num_heads).__name__}"
            )

    # Keep only the kwargs that correspond to config fields; anything else
    # is dropped.
    field_names = {
        field.name for field in dataclasses.fields(MaskGiTUViT_v2Config)
    }
    config_kwargs = {
        name: value for name, value in kwargs.items() if name in field_names
    }

    # Fields not present in ``config_kwargs`` take the dataclass defaults.
    config = MaskGiTUViT_v2Config(**config_kwargs)
    config.block_out_channels = list(config.block_out_channels)

    return config