def extract_timm_shapes_from_config()

in optimum_benchmark/backends/timm_utils.py [0:0]


def extract_timm_shapes_from_config(config: "PretrainedConfig") -> Dict[str, Any]:
    """Infer image input shapes (channels, height, width) from a timm model config.

    The config is converted with ``config.to_dict()`` and entries whose value is
    ``None`` are ignored (treated as missing). Shapes are looked up in this
    order, later sources overriding earlier ones:

    * ``num_channels`` (or, failing that, ``channels``) -> ``num_channels``.
    * ``image_size`` (or, failing that, ``size``) -> ``height`` / ``width``:
        - an int/float is used for both sides;
        - a list/tuple is read as ``(height, width)``; a single element is
          used for both sides;
        - a dict uses its ``height`` / ``width`` keys when both are present,
          otherwise its values in insertion order (one entry -> both sides,
          two entries -> height then width).
    * ``input_size``, read as ``(num_channels, height, width)``, overrides all
      of the above.

    Args:
        config: Model config object exposing ``to_dict()``.

    Returns:
        Dict containing whichever of ``num_channels``, ``height`` and ``width``
        could be inferred; keys that cannot be inferred are omitted.

    Raises:
        ImportError: If timm is not installed.
    """
    if not is_timm_available():
        raise ImportError("timm is not available. Please, pip install timm.")

    # Drop unset (None) entries so that a None value behaves like a missing key.
    config_dict = {k: v for k, v in config.to_dict().items() if v is not None}

    shapes: Dict[str, Any] = {}

    # image input: number of channels
    if "num_channels" in config_dict:
        shapes["num_channels"] = config_dict["num_channels"]
    elif "channels" in config_dict:
        shapes["num_channels"] = config_dict["channels"]

    # image input: spatial size ("image_size" preferred over "size")
    image_size = config_dict.get("image_size", config_dict.get("size"))

    if isinstance(image_size, (int, float)):
        shapes["height"] = image_size
        shapes["width"] = image_size
    elif isinstance(image_size, (list, tuple)) and len(image_size) > 0:
        # Sequence is (height, width); a single-element sequence means square.
        shapes["height"] = image_size[0]
        shapes["width"] = image_size[1] if len(image_size) > 1 else image_size[0]
    elif isinstance(image_size, dict) and "height" in image_size and "width" in image_size:
        # Prefer explicit keys so key order cannot swap height and width.
        shapes["height"] = image_size["height"]
        shapes["width"] = image_size["width"]
    elif isinstance(image_size, dict) and len(image_size) in (1, 2):
        # Unlabeled dict: fall back to insertion order. With one entry,
        # values[-1] is values[0], so the image is treated as square.
        values = list(image_size.values())
        shapes["height"] = values[0]
        shapes["width"] = values[-1]

    # input_size is read as (num_channels, height, width) and takes precedence
    # over everything inferred above.
    if "input_size" in config_dict:
        input_size = config_dict["input_size"]
        shapes["num_channels"] = input_size[0]
        shapes["height"] = input_size[1]
        shapes["width"] = input_size[2]

    return shapes