src/sagemaker_pytorch_container/training.py
# Copyright 2018-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import
import logging
import os
import socket
import sys
from retrying import retry
import six
from sagemaker_training import entry_point, environment, errors, runner
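# Port used for the torch.distributed rendezvous (exported as MASTER_PORT below).
# The LAUNCH_* names are the keys under which the SageMaker Python SDK records the
# selected distribution strategy; train() reads them back from
# additional_framework_parameters to pick a runner.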
MASTER_PORT = '7777'
LAUNCH_SMDATAPARALLEL_ENV_NAME = 'sagemaker_distributed_dataparallel_enabled'
LAUNCH_MPI_ENV_NAME = 'sagemaker_mpi_enabled'
LAUNCH_PYTORCH_DDP_ENV_NAME = "sagemaker_pytorch_ddp_enabled"
LAUNCH_PYTORCH_XLA_ENV_NAME = "sagemaker_pytorch_xla_multi_worker_enabled"
LAUNCH_TORCH_DISTRIBUTED_ENV_NAME = "sagemaker_torch_distributed_enabled"
logger = logging.getLogger(__name__)
def train(training_environment):
"""Run PyTorch training on a user supplied module.
The user supplied module is run in either a local or distributed SageMaker
environment.
The user supplied module and its dependencies are downloaded from S3.
Training is invoked by calling a "train" function in the user supplied module.
if the environment contains multiple hosts, then a distributed learning
task is started.
Args:
training_environment: training environment object containing environment
variables, training arguments and hyperparameters.
"""
_sm_studio_local_mode = os.environ.get("SM_STUDIO_LOCAL_MODE", "False").lower() == "true"
if not _sm_studio_local_mode:
        # Block until all host DNS lookups succeed. Relies on the retry-decorated _dns_lookup.
logger.info('Block until all host DNS lookups succeed.')
for host in training_environment.hosts:
_dns_lookup(host)
else:
        logger.info('Bypassing DNS check for Studio Local Mode execution.')
_set_nccl_environment(training_environment.network_interface_name)
_set_distributed_environment(training_environment)
mpi_enabled = training_environment.additional_framework_parameters.get(LAUNCH_MPI_ENV_NAME)
pytorch_ddp_enabled = training_environment.additional_framework_parameters.get(
LAUNCH_PYTORCH_DDP_ENV_NAME, False
)
smdataparallel_enabled = training_environment.additional_framework_parameters.get(
LAUNCH_SMDATAPARALLEL_ENV_NAME, False
)
pytorch_xla_enabled = training_environment.additional_framework_parameters.get(
LAUNCH_PYTORCH_XLA_ENV_NAME, False
)
torch_distributed_enabled = training_environment.additional_framework_parameters.get(
LAUNCH_TORCH_DISTRIBUTED_ENV_NAME, False
)
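    # Illustrative mapping (an assumption, not exhaustive) from the SageMaker Python SDK's
    # ``distribution`` estimator argument to the flags read above:
    #   {"mpi": {"enabled": True}}               -> sagemaker_mpi_enabled
    #   {"torch_distributed": {"enabled": True}} -> sagemaker_torch_distributed_enabled
    #   {"smdistributed": {"dataparallel": {"enabled": True}}}
    #                                            -> sagemaker_distributed_dataparallel_enabled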
# default scenario
runner_type = runner.ProcessRunnerType
if training_environment.current_instance_group in training_environment.distribution_instance_groups:
if mpi_enabled:
runner_type = runner.MPIRunnerType
elif pytorch_ddp_enabled:
runner_type = runner.SMDataParallelRunnerType
logger.info('Invoking SMDataParallel for native PT DDP job')
elif torch_distributed_enabled:
runner_type = runner.TorchDistributedRunnerType
logger.info('Invoking TorchDistributed...')
elif smdataparallel_enabled:
runner_type = runner.SMDataParallelRunnerType
logger.info('Invoking SMDataParallel')
elif pytorch_xla_enabled:
runner_type = runner.PyTorchXLARunnerType
logger.info('Invoking PT-XLA Runner')
logger.info('Invoking user training script.')
# get capture_error from framework parameters
capture_error = True
if training_environment.additional_framework_parameters.get("sagemaker_toolkit_native_launcher_enabled"):
capture_error = False
logger.info(f'capture_error is {capture_error}. Default is True')
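    # capture_error is forwarded to entry_point.run below; when True, sagemaker_training
    # captures the user script's output so it can be surfaced in the raised
    # ExecuteUserScriptError (and hence in the job's failure reason); when False the
    # output is streamed directly instead.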
_set_torch_version_environment()
try:
entry_point.run(uri=training_environment.module_dir,
user_entry_point=training_environment.user_entry_point,
args=training_environment.to_cmd_args(),
env_vars=training_environment.to_env_vars(),
capture_error=capture_error,
runner_type=runner_type)
except errors.ExecuteUserScriptError as err:
message = str(err)
        if "terminate called after throwing an instance of 'gloo::EnforceNotMet'" in message:
            logger.warning('Known exception: {}'.format(message))
else:
info = sys.exc_info()
six.reraise(info[0], err, info[2])
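# Retry DNS lookups for up to 15 minutes (stop_max_delay is in milliseconds),
# backing off exponentially from ~100 ms up to a 30 second cap between attempts.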
@retry(stop_max_delay=1000 * 60 * 15,
wait_exponential_multiplier=100,
wait_exponential_max=30000)
def _dns_lookup(host):
"""Retry DNS lookup on host."""
return socket.gethostbyname(host)
def _set_distributed_environment(training_env):
    """Set environment variables for distributed training.
    Args:
        training_env: training environment object; its master_hostname is used
            as the rendezvous address for the training cluster.
    """
# According to https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-training-algo.html
# hosts are sorted lexicographically.
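    # torch.distributed reads MASTER_ADDR and MASTER_PORT from the environment when
    # process groups are initialized with the default env:// rendezvous.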
os.environ['MASTER_ADDR'] = training_env.master_hostname
os.environ['MASTER_PORT'] = MASTER_PORT
def _set_nccl_environment(network_interface_name):
"""Set NCCL environment variables for the container.
https://docs.nvidia.com/deeplearning/sdk/nccl-developer-guide/index.html#ncclknobs
Args:
network_interface_name: The name of the network interface to use for
distributed training.
"""
# Set the network interface for inter node communication
os.environ['NCCL_SOCKET_IFNAME'] = network_interface_name
# Disable IB transport and force to use IP sockets by default
os.environ['NCCL_IB_DISABLE'] = '1'
# Set to INFO for more NCCL debugging information
os.environ['NCCL_DEBUG'] = 'WARN'
def _set_torch_version_environment():
"""Set PyTorch version environment variable.
This is the PyTorch version of the DLC.
"""
try:
import torch
os.environ["SM_DLC_TORCH_VERSION"] = torch.__version__
except ModuleNotFoundError:
logger.warn("PyTorch cannot be found")
except ImportError:
logger.warn("PyTorch can be found, but cannot be imported")
except Exception:
logger.warn("Torch version environment variable cannot be set")
def main():
    """Build the SageMaker training environment and run training."""
    train(environment.Environment())