in src/sagemaker_huggingface_inference_toolkit/optimum_utils.py [0:0]
def _parse_positive_int_env(name, raw_value):
    """Convert the raw value of environment variable ``name`` to a positive int.

    Args:
        name: Name of the environment variable, used only in the error message.
        raw_value: Value to convert (a string from ``os.environ`` or an int default).

    Returns:
        The value as an ``int`` greater than zero.

    Raises:
        ValueError: If ``raw_value`` cannot be parsed as an integer or is not positive.
    """
    message = f"{name} must be set to a positive integer. Current value is {raw_value}"
    try:
        value = int(raw_value)
    except (TypeError, ValueError) as err:
        # Re-raise with an actionable message instead of the bare "invalid literal" error.
        raise ValueError(message) from err
    if value <= 0:
        raise ValueError(message)
    return value


def get_input_shapes(model_dir):
    """Resolve the static input shapes (batch size and sequence length) for a model.

    The shapes are looked up in two places, in order:

    1. The ``neuron`` section of the model config loaded from ``model_dir``, when it
       provides both ``static_batch_size`` and ``static_sequence_length``.
    2. The ``HF_OPTIMUM_SEQUENCE_LENGTH`` (required) and ``HF_OPTIMUM_BATCH_SIZE``
       (optional, defaults to 1) environment variables.

    Args:
        model_dir: Location passed to ``AutoConfig.from_pretrained``.

    Returns:
        A dict with the keys ``"batch_size"`` and ``"sequence_length"``.

    Raises:
        ValueError: If the config provides no shapes and ``HF_OPTIMUM_SEQUENCE_LENGTH``
            is unset, or if either environment variable is not a positive integer.
    """
    from transformers import AutoConfig

    input_shapes = None
    # try to get input shapes from config file
    try:
        config = AutoConfig.from_pretrained(model_dir)
        if hasattr(config, "neuron"):
            # both static shapes must be present (and non-zero) to be usable
            static_batch_size = config.neuron.get("static_batch_size", None)
            static_sequence_length = config.neuron.get("static_sequence_length", None)
            if static_batch_size and static_sequence_length:
                input_shapes = {
                    "batch_size": static_batch_size,
                    "sequence_length": static_sequence_length,
                }
    except Exception:
        # Best effort: any failure to load or read the config falls back to the
        # environment variables, but leave a trace instead of swallowing it silently.
        logger.debug("Could not read input shapes from model config, falling back to environment variables.", exc_info=True)
        input_shapes = None

    # return input shapes from config if available
    if input_shapes is not None:
        logger.info(
            f"Input shapes found in config file. Using input shapes from config with batch size {input_shapes['batch_size']} and sequence length {input_shapes['sequence_length']}"
        )
        # Warn only here: this is the one path where the environment variables are really ignored.
        for env_name in ("HF_OPTIMUM_BATCH_SIZE", "HF_OPTIMUM_SEQUENCE_LENGTH"):
            if os.environ.get(env_name, None) is not None:
                logger.warning(
                    f"{env_name} environment variable is set. Environment variable will be ignored and input shapes from config file will be used."
                )
        return input_shapes

    # extract input shapes from environment variables
    raw_sequence_length = os.environ.get("HF_OPTIMUM_SEQUENCE_LENGTH", None)
    if raw_sequence_length is None:
        raise ValueError(
            "HF_OPTIMUM_SEQUENCE_LENGTH environment variable is not set. Please set HF_OPTIMUM_SEQUENCE_LENGTH to a positive integer."
        )
    sequence_length = _parse_positive_int_env("HF_OPTIMUM_SEQUENCE_LENGTH", raw_sequence_length)
    # batch size is optional and defaults to 1, but must be a positive integer when given
    batch_size = _parse_positive_int_env("HF_OPTIMUM_BATCH_SIZE", os.environ.get("HF_OPTIMUM_BATCH_SIZE", 1))
    logger.info(
        f"Using input shapes from environment variables with batch size {batch_size} and sequence length {sequence_length}"
    )
    return {"batch_size": batch_size, "sequence_length": sequence_length}