src/sagemaker/modules/train/container_drivers/common/utils.py (119 lines of code) (raw):
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""This module provides utility functions for the container drivers."""
from __future__ import absolute_import
import os
import logging
import sys
import subprocess
import traceback
import json
from typing import List, Dict, Any, Tuple, IO, Optional
# Initialize logger
# The level is read from the SM_LOG_LEVEL environment variable and falls back to 20
# (the numeric value of logging.INFO). The value is converted with int() below, so
# it has to be a numeric level; a non-numeric value raises ValueError at import time.
SM_LOG_LEVEL = os.environ.get("SM_LOG_LEVEL", 20)
logger = logging.getLogger(__name__)
# Records from this module's logger are written to stdout.
console_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(console_handler)
logger.setLevel(int(SM_LOG_LEVEL))
# Path of the file created by write_failure_file().
FAILURE_FILE = "/opt/ml/output/failure"
# Fallback text used by write_failure_file() when no message is supplied; the
# {training_job_name} placeholder is filled from the TRAINING_JOB_NAME env variable.
DEFAULT_FAILURE_MESSAGE = """
Training Execution failed.
For more details, see CloudWatch logs at 'aws/sagemaker/TrainingJobs'.
TrainingJob - {training_job_name}
"""
# Not referenced by the functions in this file; presumably the directory holding the
# user's source code inside the container -- NOTE(review): confirm against callers.
USER_CODE_PATH = "/opt/ml/input/data/code"
# Default paths read by read_source_code_json(), read_distributed_json() and
# read_hyperparameters_json(), in that order.
SOURCE_CODE_JSON = "/opt/ml/input/data/sm_drivers/sourcecode.json"
DISTRIBUTED_JSON = "/opt/ml/input/data/sm_drivers/distributed.json"
HYPERPARAMETERS_JSON = "/opt/ml/input/config/hyperparameters.json"
# Instance type names. Not consumed by the functions in this file; going by the
# constant name these are presumably the instance types for which EFA/NCCL-specific
# settings are applied -- NOTE(review): confirm against the drivers that import it.
SM_EFA_NCCL_INSTANCES = [
"ml.g4dn.8xlarge",
"ml.g4dn.12xlarge",
"ml.g5.48xlarge",
"ml.p3dn.24xlarge",
"ml.p4d.24xlarge",
"ml.p4de.24xlarge",
"ml.p5.48xlarge",
"ml.trn1.32xlarge",
]
# Every entry here also appears in SM_EFA_NCCL_INSTANCES above. Presumably the
# instance types that additionally get EFA/RDMA-specific settings -- NOTE(review):
# confirm against the drivers that import it.
SM_EFA_RDMA_INSTANCES = [
"ml.p4d.24xlarge",
"ml.p4de.24xlarge",
"ml.trn1.32xlarge",
]
def write_failure_file(message: Optional[str] = None):
    """Create the failure file containing ``message`` unless it already exists.

    Args:
        message (Optional[str]): Text to store in the failure file. When ``None``,
            ``DEFAULT_FAILURE_MESSAGE`` is used, formatted with the value of the
            ``TRAINING_JOB_NAME`` environment variable.

    Raises:
        KeyError: If ``message`` is ``None`` and ``TRAINING_JOB_NAME`` is not set.
    """
    failure_text = message
    if failure_text is None:
        job_name = os.environ["TRAINING_JOB_NAME"]
        failure_text = DEFAULT_FAILURE_MESSAGE.format(training_job_name=job_name)

    # An existing failure file is left untouched; nothing is overwritten.
    if os.path.exists(FAILURE_FILE):
        return

    with open(FAILURE_FILE, "w") as failure_file:
        failure_file.write(failure_text)
def read_source_code_json(source_code_json: str = SOURCE_CODE_JSON) -> Dict[str, Any]:
    """Read the source code config json file.

    Args:
        source_code_json (str): Path of the JSON file to read. Defaults to
            ``SOURCE_CODE_JSON``.

    Returns:
        Dict[str, Any]: The parsed content of the file, or an empty dict when the
            file does not exist or its parsed content is falsy (e.g. ``null``).

    Raises:
        json.JSONDecodeError: If the file exists but does not contain valid JSON.
    """
    try:
        # JSON text is UTF-8; decode explicitly instead of relying on the locale's
        # preferred encoding, which may differ inside the container.
        with open(source_code_json, "r", encoding="utf-8") as f:
            source_code_dict = json.load(f) or {}
    except FileNotFoundError:
        # A missing config file is treated the same as an empty config.
        source_code_dict = {}
    return source_code_dict
def read_distributed_json(distributed_json: str = DISTRIBUTED_JSON) -> Dict[str, Any]:
    """Read the distribution config json file.

    Args:
        distributed_json (str): Path of the JSON file to read. Defaults to
            ``DISTRIBUTED_JSON``.

    Returns:
        Dict[str, Any]: The parsed content of the file, or an empty dict when the
            file does not exist or its parsed content is falsy (e.g. ``null``).

    Raises:
        json.JSONDecodeError: If the file exists but does not contain valid JSON.
    """
    try:
        # JSON text is UTF-8; decode explicitly instead of relying on the locale's
        # preferred encoding, which may differ inside the container.
        with open(distributed_json, "r", encoding="utf-8") as f:
            distributed_dict = json.load(f) or {}
    except FileNotFoundError:
        # A missing config file is treated the same as an empty config.
        distributed_dict = {}
    return distributed_dict
def read_hyperparameters_json(hyperparameters_json: str = HYPERPARAMETERS_JSON) -> Dict[str, Any]:
    """Read the hyperparameters config json file.

    Args:
        hyperparameters_json (str): Path of the JSON file to read. Defaults to
            ``HYPERPARAMETERS_JSON``.

    Returns:
        Dict[str, Any]: The parsed content of the file, or an empty dict when the
            file does not exist or its parsed content is falsy (e.g. ``null``).

    Raises:
        json.JSONDecodeError: If the file exists but does not contain valid JSON.
    """
    try:
        # JSON text is UTF-8; decode explicitly instead of relying on the locale's
        # preferred encoding, which may differ inside the container.
        with open(hyperparameters_json, "r", encoding="utf-8") as f:
            hyperparameters_dict = json.load(f) or {}
    except FileNotFoundError:
        # A missing config file is treated the same as an empty config.
        hyperparameters_dict = {}
    return hyperparameters_dict
def get_process_count(process_count: Optional[int] = None) -> int:
    """Return how many processes to launch on each node of the training job.

    The first truthy value wins, checked in this order: the explicit
    ``process_count`` argument, the ``SM_NUM_GPUS`` environment variable, the
    ``SM_NUM_NEURONS`` environment variable. If none is set (or all are zero),
    a single process is used.

    Args:
        process_count (Optional[int]): Explicit number of processes per node.

    Returns:
        int: The number of processes per node.
    """
    if process_count:
        return process_count

    gpu_count = int(os.environ.get("SM_NUM_GPUS", 0))
    if gpu_count:
        return gpu_count

    # Only consulted when no GPU count is available.
    neuron_count = int(os.environ.get("SM_NUM_NEURONS", 0))
    if neuron_count:
        return neuron_count

    return 1
def hyperparameters_to_cli_args(hyperparameters: Dict[str, Any]) -> List[str]:
    """Flatten a hyperparameter mapping into a list of command-line arguments.

    Each entry contributes two items, ``--<key>`` followed by its value. The value
    is first passed through ``safe_deserialize`` and then rendered with
    ``safe_serialize``.

    Args:
        hyperparameters (Dict[str, Any]): Hyperparameter names mapped to values.

    Returns:
        List[str]: The flag/value pairs in the mapping's iteration order.
    """
    arguments: List[str] = []
    for name, raw_value in hyperparameters.items():
        rendered_value = safe_serialize(safe_deserialize(raw_value))
        arguments.append(f"--{name}")
        arguments.append(rendered_value)
    return arguments
def safe_deserialize(data: Any) -> Any:
    """Decode a JSON string when possible, otherwise hand the input back.

    Behaviour by input:
        1. Anything that is not a string is returned unchanged.
        2. The strings "true" and "false", in any letter case, become the booleans
           ``True`` and ``False``.
        3. Any other string is parsed with ``json.loads()``.
        4. A string that is not valid JSON is returned unchanged.

    Args:
        data (Any): The value to deserialize.

    Returns:
        Any: The deserialized data, or the original input if it cannot be JSON-decoded.
    """
    if isinstance(data, str):
        normalized = data.lower()
        if normalized == "true":
            return True
        if normalized == "false":
            return False
        try:
            decoded = json.loads(data)
        except json.JSONDecodeError:
            return data
        return decoded
    return data
def safe_serialize(data: Any) -> str:
    """Serialize the data without wrapping strings in quotes.

    This function handles the following cases:
    1. If `data` is a string, it returns the string as-is without wrapping in quotes.
    2. If `data` is serializable (e.g., a dictionary, list, int, float), it returns
       the JSON-encoded string using `json.dumps()`.
    3. If `data` cannot be serialized (e.g., a custom object, or a container that
       references itself), it returns the string representation of the data
       using `str(data)`.

    Args:
        data (Any): The data to serialize.

    Returns:
        str: The serialized JSON-compatible string or the string representation of the input.
    """
    if isinstance(data, str):
        return data
    try:
        return json.dumps(data)
    except (TypeError, ValueError):
        # TypeError: an object json does not know how to encode.
        # ValueError: json.dumps detected a circular reference.
        return str(data)
def get_python_executable() -> str:
    """Return the path of the running interpreter as reported by ``sys.executable``."""
    interpreter_path = sys.executable
    return interpreter_path
def log_subprocess_output(pipe: IO[bytes]):
    """Log the output from the subprocess.

    Reads ``pipe`` line by line until ``readline()`` returns ``b""`` and logs each
    line at INFO level with leading and trailing whitespace stripped.

    Args:
        pipe (IO[bytes]): Binary stream to read from, e.g. the stdout pipe of a
            subprocess.
    """
    for line in iter(pipe.readline, b""):
        # The child process may emit bytes that are not valid UTF-8. Substitute them
        # with U+FFFD instead of letting UnicodeDecodeError abort log forwarding.
        logger.info(line.decode("utf-8", errors="replace").strip())
def execute_commands(commands: List[str]) -> Tuple[int, str]:
    """Run ``commands`` as a subprocess, forwarding its output to the logger.

    stderr is merged into stdout, and the combined stream is passed to
    ``log_subprocess_output`` until the stream is exhausted.

    Args:
        commands (List[str]): The argument vector handed to ``subprocess.Popen``.

    Returns:
        Tuple[int, str]: ``(0, "")`` when the process exits successfully, otherwise
            the non-zero exit code together with a formatted traceback of the
            ``CalledProcessError`` raised for it.
    """
    try:
        proc = subprocess.Popen(commands, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
        # Drain the pipe to EOF before waiting so the child cannot block on a full pipe.
        with proc.stdout:
            log_subprocess_output(proc.stdout)
        exitcode = proc.wait()
        if exitcode != 0:
            # Raised only so that format_exc() below has an active exception to render.
            raise subprocess.CalledProcessError(exitcode, commands)
    except subprocess.CalledProcessError as e:
        # Capture the traceback in case of failure
        error_traceback = traceback.format_exc()
        print(f"Command failed with exit code {e.returncode}. Traceback: {error_traceback}")
        return e.returncode, error_traceback
    else:
        return exitcode, ""
def is_worker_node() -> bool:
    """Tell whether this node is a worker, i.e. not the master node.

    Returns:
        bool: ``True`` when the ``SM_CURRENT_HOST`` and ``SM_MASTER_ADDR``
            environment variables differ (an unset variable reads as ``None``).
    """
    current_host = os.environ.get("SM_CURRENT_HOST")
    master_addr = os.environ.get("SM_MASTER_ADDR")
    return current_host != master_addr
def is_master_node() -> bool:
    """Tell whether this node is the master node.

    Returns:
        bool: ``True`` when the ``SM_CURRENT_HOST`` and ``SM_MASTER_ADDR``
            environment variables are equal (two unset variables also compare equal).
    """
    current_host = os.environ.get("SM_CURRENT_HOST")
    master_addr = os.environ.get("SM_MASTER_ADDR")
    return current_host == master_addr