# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Holds the util functions used for the optimize function"""
from __future__ import absolute_import
import re
import logging
from typing import Dict, Any, Optional, Union, List, Tuple
from sagemaker import Model, Session
from sagemaker.enums import Tag
from sagemaker.jumpstart.utils import accessors, get_eula_message
logger = logging.getLogger(__name__)
SPECULATIVE_DRAFT_MODEL = "/opt/ml/additional-model-data-sources"
def _is_inferentia_or_trainium(instance_type: Optional[str]) -> bool:
"""Checks whether an instance is compatible with Inferentia.
Args:
instance_type (str): The instance type used for the compilation job.
Returns:
bool: Whether the given instance type is Inferentia or Trainium.
"""
if isinstance(instance_type, str):
match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
if match:
if match[1].startswith("inf") or match[1].startswith("trn"):
return True
return False
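# Illustrative usage, a minimal sketch (the instance types below are examples,
# not an exhaustive list):
# _is_inferentia_or_trainium("ml.inf2.xlarge")    # -> True
# _is_inferentia_or_trainium("ml.trn1.32xlarge")  # -> True
# _is_inferentia_or_trainium("ml.g5.2xlarge")     # -> False
# _is_inferentia_or_trainium(None)                # -> False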
def _is_image_compatible_with_optimization_job(image_uri: Optional[str]) -> bool:
"""Checks whether an instance is compatible with an optimization job.
Args:
image_uri (str): The image URI of the optimization job.
Returns:
bool: Whether the given image is compatible with an optimization job.
"""
# TODO: Use specific container type instead.
if image_uri is None:
return True
return "djl-inference:" in image_uri and ("-lmi" in image_uri or "-neuronx-" in image_uri)
def _deployment_config_contains_draft_model(deployment_config: Optional[Dict]) -> bool:
"""Checks whether a deployment config contains a speculative decoding draft model.
Args:
deployment_config (Dict): The deployment config to check.
Returns:
bool: Whether the deployment config contains a draft model or not.
"""
if deployment_config is None:
return False
deployment_args = deployment_config.get("DeploymentArgs", {})
additional_data_sources = deployment_args.get("AdditionalDataSources")
return "speculative_decoding" in additional_data_sources if additional_data_sources else False
def _is_draft_model_jumpstart_provided(deployment_config: Optional[Dict]) -> bool:
"""Checks whether a deployment config's draft model is provided by JumpStart.
Args:
deployment_config (Dict): The deployment config to check.
Returns:
bool: Whether the draft model is provided by JumpStart or not.
"""
if deployment_config is None:
return False
# Default to an empty dict so a missing "AdditionalDataSources" key does not
# raise an AttributeError below.
additional_model_data_sources = (
deployment_config.get("DeploymentArgs", {}).get("AdditionalDataSources") or {}
)
for source in additional_model_data_sources.get("speculative_decoding", []):
if (
source.get("channel_name") == "draft_model"
and source.get("provider", {}).get("name") == "JumpStart"
):
return True
return False
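# Illustrative usage (a minimal sketch of the expected config shape):
# _is_draft_model_jumpstart_provided({
#     "DeploymentArgs": {
#         "AdditionalDataSources": {
#             "speculative_decoding": [
#                 {"channel_name": "draft_model", "provider": {"name": "JumpStart"}}
#             ]
#         }
#     }
# })  # -> True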
def _generate_optimized_model(pysdk_model: Model, optimization_response: dict) -> Model:
"""Generates a new optimization model.
Args:
pysdk_model (Model): A PySDK model.
optimization_response (dict): The optimization response.
Returns:
Model: A deployable optimized model.
"""
recommended_image_uri = optimization_response.get("OptimizationOutput", {}).get(
"RecommendedInferenceImage"
)
s3_uri = optimization_response.get("OutputConfig", {}).get("S3OutputLocation")
deployment_instance_type = optimization_response.get("DeploymentInstanceType")
if recommended_image_uri:
pysdk_model.image_uri = recommended_image_uri
if s3_uri:
pysdk_model.model_data["S3DataSource"]["S3Uri"] = s3_uri
if deployment_instance_type:
pysdk_model.instance_type = deployment_instance_type
pysdk_model.add_tags(
{"Key": Tag.OPTIMIZATION_JOB_NAME, "Value": optimization_response["OptimizationJobName"]}
)
return pysdk_model
def _is_optimized(pysdk_model: Model) -> bool:
"""Checks whether an optimization model is optimized.
Args:
pysdk_model (Model): A PySDK model.
Returns:
bool: Whether the given model is optimized.
"""
optimized_tags = [Tag.OPTIMIZATION_JOB_NAME, Tag.SPECULATIVE_DRAFT_MODEL_PROVIDER]
if hasattr(pysdk_model, "_tags") and pysdk_model._tags:
if isinstance(pysdk_model._tags, dict):
return pysdk_model._tags.get("Key") in optimized_tags
for tag in pysdk_model._tags:
if tag.get("Key") in optimized_tags:
return True
return False
def _generate_model_source(
model_data: Optional[Union[Dict[str, Any], str]], accept_eula: Optional[bool]
) -> Optional[Dict[str, Any]]:
"""Extracts model source from model data.
Args:
model_data (Optional[Union[Dict[str, Any], str]]): The model data.
accept_eula (Optional[bool]): Whether the EULA was accepted for the model.
Returns:
Optional[Dict[str, Any]]: Model source data.
"""
if model_data is None:
raise ValueError("Model Optimization Job only supports model with S3 data source.")
s3_uri = model_data
if isinstance(s3_uri, dict):
s3_uri = s3_uri.get("S3DataSource").get("S3Uri")
model_source = {"S3": {"S3Uri": s3_uri}}
if accept_eula:
model_source["S3"]["ModelAccessConfig"] = {"AcceptEula": True}
return model_source
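# Illustrative usage (the bucket and prefix are hypothetical):
# _generate_model_source({"S3DataSource": {"S3Uri": "s3://my-bucket/model/"}}, accept_eula=True)
# # -> {"S3": {"S3Uri": "s3://my-bucket/model/", "ModelAccessConfig": {"AcceptEula": True}}}
# _generate_model_source("s3://my-bucket/model/", accept_eula=None)
# # -> {"S3": {"S3Uri": "s3://my-bucket/model/"}}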
def _update_environment_variables(
env: Optional[Dict[str, str]], new_env: Optional[Dict[str, str]]
) -> Optional[Dict[str, str]]:
"""Updates environment variables based on environment variables.
Args:
env (Optional[Dict[str, str]]): The environment variables.
new_env (Optional[Dict[str, str]]): The new environment variables.
Returns:
Optional[Dict[str, str]]: The updated environment variables.
"""
if new_env:
if env:
env.update(new_env)
else:
env = new_env
return env
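# Illustrative usage (note that `env` is mutated in place when both dicts are given):
# _update_environment_variables({"A": "1"}, {"B": "2"})  # -> {"A": "1", "B": "2"}
# _update_environment_variables(None, {"B": "2"})        # -> {"B": "2"}
# _update_environment_variables({"A": "1"}, None)        # -> {"A": "1"}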
def _extract_speculative_draft_model_provider(
speculative_decoding_config: Optional[Dict] = None,
) -> Optional[str]:
"""Extracts speculative draft model provider from speculative decoding config.
Args:
speculative_decoding_config (Optional[Dict]): A speculative decoding config.
Returns:
Optional[str]: The speculative draft model provider.
"""
if speculative_decoding_config is None:
return None
model_provider = speculative_decoding_config.get("ModelProvider", "").lower()
if model_provider == "jumpstart":
return "jumpstart"
if model_provider == "custom" or speculative_decoding_config.get("ModelSource"):
return "custom"
if model_provider == "sagemaker":
return "sagemaker"
return "auto"
def _extract_additional_model_data_source_s3_uri(
additional_model_data_source: Optional[Dict] = None,
) -> Optional[str]:
"""Extracts model data source s3 uri from a model data source in Pascal case.
Args:
additional_model_data_source (Optional[Dict]): A model data source.
Returns:
Optional[str]: The S3 URI of the model resources, or None if absent.
"""
if (
additional_model_data_source is None
or additional_model_data_source.get("S3DataSource", None) is None
):
return None
return additional_model_data_source.get("S3DataSource").get("S3Uri")
def _extract_deployment_config_additional_model_data_source_s3_uri(
additional_model_data_source: Optional[Dict] = None,
) -> Optional[str]:
"""Extracts model data source s3 uri from a model data source in snake case.
Args:
additional_model_data_source (Optional[Dict]): A model data source.
Returns:
Optional[str]: The S3 URI of the model resources, or None if absent.
"""
if (
additional_model_data_source is None
or additional_model_data_source.get("s3_data_source", None) is None
):
return None
return additional_model_data_source.get("s3_data_source").get("s3_uri", None)
def _is_draft_model_gated(
draft_model_config: Optional[Dict] = None,
) -> bool:
"""Extracts model gated-ness from draft model data source.
Args:
draft_model_config (Optional[Dict]): A model data source.
Returns:
bool: Whether the draft model is gated or not.
"""
return "hosting_eula_key" in draft_model_config if draft_model_config else False
def _extracts_and_validates_speculative_model_source(
speculative_decoding_config: Dict,
) -> str:
"""Extracts model source from speculative decoding config.
Args:
speculative_decoding_config (Dict): A speculative decoding config.
Returns:
str: Model source.
Raises:
ValueError: If the model source is None.
"""
model_source: str = speculative_decoding_config.get("ModelSource")
if not model_source:
raise ValueError("ModelSource must be provided in speculative decoding config.")
return model_source
def _generate_channel_name(additional_model_data_sources: Optional[List[Dict]]) -> str:
"""Generates a channel name.
Args:
additional_model_data_sources (Optional[List[Dict]]): The additional model data sources.
Returns:
str: The channel name.
"""
channel_name = "draft_model"
if additional_model_data_sources and len(additional_model_data_sources) > 0:
channel_name = additional_model_data_sources[0].get("ChannelName", channel_name)
return channel_name
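# Illustrative usage:
# _generate_channel_name(None)                             # -> "draft_model"
# _generate_channel_name([{"ChannelName": "my_channel"}])  # -> "my_channel"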
def _generate_additional_model_data_sources(
model_source: str,
channel_name: str,
accept_eula: bool = False,
s3_data_type: Optional[str] = "S3Prefix",
compression_type: Optional[str] = "None",
) -> List[Dict]:
"""Generates additional model data sources.
Args:
model_source (str): The model source.
channel_name (str): The channel name.
accept_eula (bool): Whether to accept the EULA or not, defaults to False.
s3_data_type (Optional[str]): The S3 data type, defaults to 'S3Prefix'.
compression_type (Optional[str]): The compression type, defaults to 'None'.
Returns:
List[Dict]: The additional model data sources.
"""
additional_model_data_source = {
"ChannelName": channel_name,
"S3DataSource": {
"S3Uri": model_source,
"S3DataType": s3_data_type,
"CompressionType": compression_type,
},
}
if accept_eula:
additional_model_data_source["S3DataSource"]["ModelAccessConfig"] = {"AcceptEula": True}
return [additional_model_data_source]
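# Illustrative usage (the bucket and prefix are hypothetical):
# _generate_additional_model_data_sources(
#     "s3://my-bucket/draft/", "draft_model", accept_eula=True
# )
# # -> [{"ChannelName": "draft_model",
# #      "S3DataSource": {"S3Uri": "s3://my-bucket/draft/",
# #                       "S3DataType": "S3Prefix",
# #                       "CompressionType": "None",
# #                       "ModelAccessConfig": {"AcceptEula": True}}}]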
def _is_s3_uri(s3_uri: Optional[str]) -> bool:
"""Checks whether an S3 URI is valid.
Args:
s3_uri (Optional[str]): The S3 URI.
Returns:
bool: Whether the S3 URI is valid.
"""
if s3_uri is None:
return False
return re.match("^s3://([^/]+)/?(.*)$", s3_uri) is not None
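# Illustrative usage:
# _is_s3_uri("s3://my-bucket/prefix/model")               # -> True
# _is_s3_uri("https://my-bucket.s3.amazonaws.com/model")  # -> False
# _is_s3_uri(None)                                        # -> False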
def _extract_optimization_config_and_env(
quantization_config: Optional[Dict] = None,
compilation_config: Optional[Dict] = None,
sharding_config: Optional[Dict] = None,
) -> Optional[Tuple[Optional[Dict], Optional[Dict], Optional[Dict], Optional[Dict]]]:
"""Extracts optimization config and environment variables.
Args:
quantization_config (Optional[Dict]): The quantization config.
compilation_config (Optional[Dict]): The compilation config.
sharding_config (Optional[Dict]): The sharding config.
Returns:
Optional[Tuple[Optional[Dict], Optional[Dict], Optional[Dict], Optional[Dict]]]:
The optimization config and environment variables.
"""
optimization_config = {}
quantization_override_env = (
quantization_config.get("OverrideEnvironment") if quantization_config else None
)
compilation_override_env = (
compilation_config.get("OverrideEnvironment") if compilation_config else None
)
sharding_override_env = sharding_config.get("OverrideEnvironment") if sharding_config else None
if quantization_config is not None:
optimization_config["ModelQuantizationConfig"] = quantization_config
if compilation_config is not None:
optimization_config["ModelCompilationConfig"] = compilation_config
if sharding_config is not None:
optimization_config["ModelShardingConfig"] = sharding_config
# Return the optimization config dict and environment variables if any config is present
if optimization_config:
return (
optimization_config,
quantization_override_env,
compilation_override_env,
sharding_override_env,
)
return None, None, None, None
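# Illustrative usage (the OverrideEnvironment key/value pair is a hypothetical example):
# _extract_optimization_config_and_env(
#     quantization_config={"OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}}
# )
# # -> ({"ModelQuantizationConfig": {"OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}}},
# #     {"OPTION_QUANTIZE": "awq"}, None, None)
# _extract_optimization_config_and_env()  # -> (None, None, None, None)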
def _custom_speculative_decoding(
model: Model,
speculative_decoding_config: Optional[Dict],
accept_eula: Optional[bool] = False,
) -> Model:
"""Modifies the given model for speculative decoding config with custom provider.
Args:
model (Model): The model.
speculative_decoding_config (Optional[Dict]): The speculative decoding config.
accept_eula (Optional[bool]): Whether to accept the EULA or not.
Returns:
Model: The model, updated with the custom draft model configuration when one is given.
"""
if speculative_decoding_config:
additional_model_source = _extracts_and_validates_speculative_model_source(
speculative_decoding_config
)
accept_eula = speculative_decoding_config.get("AcceptEula", accept_eula)
if _is_s3_uri(additional_model_source):
channel_name = _generate_channel_name(model.additional_model_data_sources)
speculative_draft_model = f"{SPECULATIVE_DRAFT_MODEL}/{channel_name}"
model.additional_model_data_sources = _generate_additional_model_data_sources(
additional_model_source, channel_name, accept_eula
)
else:
speculative_draft_model = additional_model_source
model.env.update({"OPTION_SPECULATIVE_DRAFT_MODEL": speculative_draft_model})
model.add_tags(
{"Key": Tag.SPECULATIVE_DRAFT_MODEL_PROVIDER, "Value": "custom"},
)
return model
def _jumpstart_speculative_decoding(
model: Model,
speculative_decoding_config: Optional[Dict[str, Any]] = None,
sagemaker_session: Optional[Session] = None,
):
"""Modifies the given model for speculative decoding config with JumpStart provider.
Args:
model (Model): The model.
speculative_decoding_config (Optional[Dict]): The speculative decoding config.
sagemaker_session (Optional[Session]): The SageMaker session used for execution.
"""
if speculative_decoding_config:
js_id = speculative_decoding_config.get("ModelID")
if not js_id:
raise ValueError(
"`ModelID` is a required field in `speculative_decoding_config` when "
"using JumpStart as draft model provider."
)
model_version = speculative_decoding_config.get("ModelVersion", "*")
accept_eula = speculative_decoding_config.get("AcceptEula", False)
channel_name = _generate_channel_name(model.additional_model_data_sources)
model_specs = accessors.JumpStartModelsAccessor.get_model_specs(
model_id=js_id,
version=model_version,
region=sagemaker_session.boto_region_name,
sagemaker_session=sagemaker_session,
)
model_spec_json = model_specs.to_json()
js_bucket = accessors.JumpStartModelsAccessor.get_jumpstart_content_bucket()
if model_spec_json.get("gated_bucket", False):
if not accept_eula:
eula_message = get_eula_message(
model_specs=model_specs, region=sagemaker_session.boto_region_name
)
raise ValueError(
f"{eula_message} Set `AcceptEula`=True in "
f"speculative_decoding_config once acknowledged."
)
js_bucket = accessors.JumpStartModelsAccessor.get_jumpstart_gated_content_bucket()
key_prefix = model_spec_json.get("hosting_prepacked_artifact_key")
model.additional_model_data_sources = _generate_additional_model_data_sources(
f"s3://{js_bucket}/{key_prefix}",
channel_name,
accept_eula,
)
model.env.update(
{"OPTION_SPECULATIVE_DRAFT_MODEL": f"{SPECULATIVE_DRAFT_MODEL}/{channel_name}/"}
)
model.add_tags(
{"Key": Tag.SPECULATIVE_DRAFT_MODEL_PROVIDER, "Value": "jumpstart"},
)