in src/sagemaker_training/environment.py [0:0]
def __init__(self, resource_config=None, input_data_config=None, hyperparameters=None):
    """Initialize a read-only snapshot of the container environment.

    Every argument that is omitted (or falsy) is loaded through the matching
    ``read_*`` helper instead.

    Args:
        resource_config (dict[string, object]): The contents from
            /opt/ml/input/config/resourceconfig.json.
            It has the following keys:
                - current_host: The name of the current container on the container
                    network. For example, 'algo-1'. Required.
                - current_instance_type: Type of EC2 instance. Defaults to 'local'
                    when absent.
                - hosts: The list of names of all nodes on the container
                    network, sorted lexicographically. For example,
                    `['algo-1', 'algo-2', 'algo-3']` for a three-node cluster.
                    Required.
                - current_group_name: Name of the current instance group. Defaults
                    to 'homogeneousCluster' when absent.
                - instance_groups: List of instance group dicts containing info about
                    instance_type, hosts list and group name
                - network_interface_name: Name of network interface exposed to the
                    container. Defaults to 'eth0' when absent.
        input_data_config (dict[string, object]): The contents from
            /opt/ml/input/config/inputdataconfig.json.
            For example, suppose that you specify three data channels (train,
            evaluation, and validation) in your request. This dictionary will contain:
                {'train': {
                    'ContentType':  'trainingContentType',
                    'TrainingInputMode': 'File',
                    'S3DistributionType': 'FullyReplicated',
                    'RecordWrapperType': 'None'
                },
                'evaluation' : {
                    'ContentType': 'evalContentType',
                    'TrainingInputMode': 'File',
                    'S3DistributionType': 'FullyReplicated',
                    'RecordWrapperType': 'None'
                },
                'validation': {
                    'TrainingInputMode': 'File',
                    'S3DistributionType': 'FullyReplicated',
                    'RecordWrapperType': 'None'
                }}
            You can find more information about
            /opt/ml/input/config/inputdataconfig.json here:
            https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-training-algo-running-container.html#your-algorithms-training-algo-running-container-inputdataconfig
        hyperparameters (dict[string, object]): An instance of `HyperParameters`
            containing the training job hyperparameters. The mapping is split with
            `mapping.split_by_criteria` into the SageMaker-reserved entries and the
            remaining entries, which are kept as the hyperparameters of this
            environment.

    Raises:
        KeyError: If ``resource_config`` has no "hosts" or "current_host" entry.

    Side effects:
        Writes the job name, the current host and the region name into
        ``os.environ``.
    """
    # Values taken from the process environment. The module directory and the
    # log level read here are provisional: both attributes are overwritten
    # from the SageMaker hyperparameters further down.
    module_name = os.environ.get(params.USER_PROGRAM_ENV, None)
    module_dir = os.environ.get(params.SUBMIT_DIR_ENV, code_dir)
    log_level = int(os.environ.get(params.LOG_LEVEL_ENV, logging.INFO))

    self._num_gpus = num_gpus()
    self._num_cpus = num_cpus()
    self._module_name = module_name
    self._user_entry_point = module_name
    self._module_dir = module_dir
    self._log_level = log_level
    self._model_dir = model_dir

    # Fall back to the on-disk configuration for anything not supplied.
    resource_config = resource_config or read_resource_config()
    input_data_config = input_data_config or read_input_data_config()
    all_hyperparameters = hyperparameters or read_hyperparameters()

    hosts = resource_config["hosts"]
    current_instance_type = resource_config.get("current_instance_type", "local")
    current_instance_group = resource_config.get("current_group_name", "homogeneousCluster")
    # The resource config is the only source used for the current host name.
    current_host = resource_config["current_host"]

    self._num_neurons = num_neurons(current_instance_type)
    self._current_host = current_host
    self._current_instance_type = current_instance_type
    self._current_instance_group = current_instance_group

    split_result = mapping.split_by_criteria(
        all_hyperparameters,
        keys=params.SAGEMAKER_HYPERPARAMETERS,
        prefix=params.SAGEMAKER_PREFIX,
    )
    sagemaker_hyperparameters = split_result.included

    # Entries that were included by the split but are not listed in
    # params.SAGEMAKER_HYPERPARAMETERS.
    additional_framework_parameters = {
        key: value
        for key, value in sagemaker_hyperparameters.items()
        if key not in params.SAGEMAKER_HYPERPARAMETERS
    }

    sagemaker_region = sagemaker_hyperparameters.get(
        params.REGION_NAME_PARAM, boto3.session.Session().region_name
    )

    os.environ[params.JOB_NAME_ENV] = sagemaker_hyperparameters.get(params.JOB_NAME_PARAM, "")
    os.environ[params.CURRENT_HOST_ENV] = current_host
    # Environment values must be strings, so a missing region becomes "".
    os.environ[params.REGION_NAME_ENV] = sagemaker_region or ""

    # hosts comprises the instances from all the groups
    self._hosts = hosts

    # eth0 is the default network interface defined by SageMaker with VPC support and
    # local mode.
    # ethwe is the current network interface defined by SageMaker training, it will be
    # changed to eth0 in the near future.
    self._network_interface_name = resource_config.get("network_interface_name", "eth0")

    self._hyperparameters = split_result.excluded
    self._additional_framework_parameters = additional_framework_parameters
    self._resource_config = resource_config
    self._input_data_config = input_data_config
    self._output_data_dir = output_data_dir
    self._output_intermediate_dir = output_intermediate_dir
    self._channel_input_dirs = {channel: channel_path(channel) for channel in input_data_config}

    # override base class attributes
    if self._module_name is None:
        self._module_name = str(sagemaker_hyperparameters.get(params.USER_PROGRAM_PARAM, None))
    self._user_entry_point = self._user_entry_point or sagemaker_hyperparameters.get(
        params.USER_PROGRAM_PARAM
    )

    self._module_dir = str(sagemaker_hyperparameters.get(params.SUBMIT_DIR_PARAM, code_dir))
    self._log_level = sagemaker_hyperparameters.get(params.LOG_LEVEL_PARAM, logging.INFO)
    self._sagemaker_s3_output = sagemaker_hyperparameters.get(
        params.S3_OUTPUT_LOCATION_PARAM, None
    )
    self._framework_module = os.environ.get(params.FRAMEWORK_TRAINING_MODULE_ENV, None)

    self._input_dir = input_dir
    self._input_config_dir = input_config_dir
    self._output_dir = output_dir

    self._job_name = os.environ.get(params.TRAINING_JOB_ENV.upper(), None)

    # Heterogeneous cluster changes - get the instance group related information.
    # NOTE: the order of the calls and assignments below is kept deliberately;
    # the helper methods are called on a partially initialized instance.
    current_instance_group_hosts = self.get_current_instance_group_hosts()
    instance_groups = self.get_instance_groups()
    instance_groups_dict = self.get_instance_groups_dict()
    distribution_instance_groups = self._additional_framework_parameters.get(
        "sagemaker_distribution_instance_groups",
        self.get_distribution_instance_groups_from_resource_config(),
    )
    self._distribution_instance_groups = distribution_instance_groups
    distribution_hosts = self.get_distribution_hosts()

    self._current_instance_group_hosts = current_instance_group_hosts
    self._instance_groups = instance_groups
    self._instance_groups_dict = instance_groups_dict
    self._distribution_hosts = distribution_hosts

    # More than one instance group means a heterogeneous cluster.
    self._is_hetero = len(self._instance_groups) > 1

    master_hostname = self.get_master_hostname()
    self._master_hostname = master_hostname
    self._is_master = current_host == self._master_hostname

    self._distribution_enabled = (
        self._current_instance_group in self._distribution_instance_groups
    )

    # `a and b` evaluates to one of its operands, so without bool() this flag
    # would be None or "" when the variable is unset or empty.
    mp_parameters = os.environ.get(params.SM_HP_MP_PARAMETERS)
    self._is_modelparallel_enabled = bool(mp_parameters and mp_parameters != "{}")
    self._is_smddprun_installed = validate_smddprun()
    self._is_smddpmprun_installed = validate_smddpmprun()