def get_input_shapes()

in src/sagemaker_huggingface_inference_toolkit/optimum_utils.py [0:0]


def _parse_positive_int_env(name, raw_value):
    """Convert the raw value of environment variable ``name`` to a positive int.

    Raises:
        ValueError: if ``raw_value`` is not an integer or is not greater than zero.
    """
    try:
        value = int(raw_value)
    except (TypeError, ValueError):
        value = None
    if value is None or value <= 0:
        raise ValueError(f"{name} must be set to a positive integer. Current value is {raw_value}")
    return value


def get_input_shapes(model_dir):
    """Return the static input shapes (batch size and sequence length) for a model.

    The shapes are looked up in two places, in order:

    1. The ``neuron`` attribute of the model config loaded from ``model_dir``,
       when it provides truthy ``static_batch_size`` and ``static_sequence_length``
       entries. Any failure while loading or reading the config is logged and
       treated as "no shapes in config".
    2. The ``HF_OPTIMUM_SEQUENCE_LENGTH`` (required) and ``HF_OPTIMUM_BATCH_SIZE``
       (optional, defaults to 1) environment variables.

    Args:
        model_dir: location handed to ``AutoConfig.from_pretrained``.

    Returns:
        A dict with the keys ``"batch_size"`` and ``"sequence_length"``.

    Raises:
        ValueError: if the config provides no shapes and ``HF_OPTIMUM_SEQUENCE_LENGTH``
            is unset, or if either environment variable is not a positive integer.
    """
    from transformers import AutoConfig

    input_shapes = {}
    # Best effort: a missing or unreadable config must not prevent the
    # environment-variable fallback below.
    try:
        config = AutoConfig.from_pretrained(model_dir)
        neuron_config = getattr(config, "neuron", None)
        if neuron_config is not None:
            static_batch_size = neuron_config.get("static_batch_size", None)
            static_sequence_length = neuron_config.get("static_sequence_length", None)
            # Both values have to be present (and truthy) to be usable.
            if static_batch_size and static_sequence_length:
                input_shapes = {
                    "batch_size": static_batch_size,
                    "sequence_length": static_sequence_length,
                }
    except Exception as error:
        logger.info(
            f"Could not read input shapes from model config ({error}). Falling back to environment variables."
        )
        input_shapes = {}

    if input_shapes:
        logger.info(
            f"Input shapes found in config file. Using input shapes from config with batch size {input_shapes['batch_size']} and sequence length {input_shapes['sequence_length']}"
        )
        # The config takes precedence, so only here are the environment
        # variables really ignored; tell the user if they set them anyway.
        for env_name in ("HF_OPTIMUM_BATCH_SIZE", "HF_OPTIMUM_SEQUENCE_LENGTH"):
            if os.environ.get(env_name, None) is not None:
                logger.warning(
                    f"{env_name} environment variable is set. Environment variable will be ignored and input shapes from config file will be used."
                )
        return input_shapes

    # No usable shapes in the config: extract them from environment variables.
    raw_sequence_length = os.environ.get("HF_OPTIMUM_SEQUENCE_LENGTH", None)
    if raw_sequence_length is None:
        raise ValueError(
            "HF_OPTIMUM_SEQUENCE_LENGTH environment variable is not set. Please set HF_OPTIMUM_SEQUENCE_LENGTH to a positive integer."
        )
    sequence_length = _parse_positive_int_env("HF_OPTIMUM_SEQUENCE_LENGTH", raw_sequence_length)
    batch_size = _parse_positive_int_env("HF_OPTIMUM_BATCH_SIZE", os.environ.get("HF_OPTIMUM_BATCH_SIZE", 1))

    logger.info(
        f"Using input shapes from environment variables with batch size {batch_size} and sequence length {sequence_length}"
    )
    return {"batch_size": batch_size, "sequence_length": sequence_length}