in src/sagemaker_training/environment.py [0:0]
def __init__(self, resource_config=None, input_data_config=None, hyperparameters=None):
    """Initialize a snapshot of the container environment.

    Settings are collected from environment variables, from the three
    configuration mappings described below, and from module-level helpers
    (``num_gpus()``, ``num_cpus()``, ``channel_path()``) and path constants
    (``model_dir``, ``code_dir``, ``input_dir``, ``output_dir``, ...).
    Any configuration argument that is falsy (``None`` or an empty dict) is
    loaded instead with ``read_resource_config()``, ``read_input_data_config()``
    or ``read_hyperparameters()`` respectively.

    Side effects: this constructor writes ``params.JOB_NAME_ENV``,
    ``params.CURRENT_HOST_ENV`` and ``params.REGION_NAME_ENV`` into
    ``os.environ``, and always constructs a ``boto3.session.Session()`` to
    obtain a fallback region name.

    Args:
        resource_config (dict[string, object]): The contents from
            /opt/ml/input/config/resourceconfig.json.
            It has the following keys:
                - current_host: The name of the current container on the container network.
                    For example, 'algo-1'.
                - hosts: The list of names of all containers on the container network,
                    sorted lexicographically. For example, `['algo-1', 'algo-2', 'algo-3']`
                    for a three-node cluster. The first entry is treated as the master host.
                - network_interface_name (optional): The network interface name.
                    Defaults to 'eth0' when the key is absent.
        input_data_config (dict[string, object]): The contents from /opt/ml/input/config/inputdataconfig.json.
            For example, suppose that you specify three data channels (train, evaluation, and
            validation) in your request. This dictionary will contain:
            {'train': {
                'ContentType': 'trainingContentType',
                'TrainingInputMode': 'File',
                'S3DistributionType': 'FullyReplicated',
                'RecordWrapperType': 'None'
            },
            'evaluation' : {
                'ContentType': 'evalContentType',
                'TrainingInputMode': 'File',
                'S3DistributionType': 'FullyReplicated',
                'RecordWrapperType': 'None'
            },
            'validation': {
                'TrainingInputMode': 'File',
                'S3DistributionType': 'FullyReplicated',
                'RecordWrapperType': 'None'
            }}
            Each top-level key (channel name) is mapped to ``channel_path(channel)``.
            You can find more information about /opt/ml/input/config/inputdataconfig.json here:
            https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-training-algo-running-container.html#your-algorithms-training-algo-running-container-inputdataconfig
        hyperparameters (dict[string, object]): The training job hyperparameters, keyed
            by name. They are split with ``mapping.split_by_criteria`` into SageMaker
            settings (names in ``params.SAGEMAKER_HYPERPARAMETERS`` or, presumably,
            names starting with ``params.SAGEMAKER_PREFIX``) and user hyperparameters
            (everything else, stored as ``self._hyperparameters``).
    """
    # Initial values taken from environment variables. Several of these are
    # overridden further down from resource_config / hyperparameters.
    current_host = os.environ.get(params.CURRENT_HOST_ENV)
    module_name = os.environ.get(params.USER_PROGRAM_ENV, None)
    module_dir = os.environ.get(params.SUBMIT_DIR_ENV, code_dir)
    # Raises ValueError if the log-level variable is set to a non-integer string.
    log_level = int(os.environ.get(params.LOG_LEVEL_ENV, logging.INFO))
    self._current_host = current_host
    self._num_gpus = num_gpus()
    self._num_cpus = num_cpus()
    self._module_name = module_name
    self._user_entry_point = module_name
    self._module_dir = module_dir
    self._log_level = log_level
    self._model_dir = model_dir
    # Fall back to reading the config files when an argument is None *or* empty.
    resource_config = resource_config or read_resource_config()
    input_data_config = input_data_config or read_input_data_config()
    all_hyperparameters = hyperparameters or read_hyperparameters()
    # Rebinds current_host: from here on the resource config value wins over
    # the environment value read above. Raises KeyError if either key is missing.
    current_host = resource_config["current_host"]
    hosts = resource_config["hosts"]
    # Separate SageMaker/framework settings (.included) from user
    # hyperparameters (.excluded).
    split_result = mapping.split_by_criteria(
        all_hyperparameters,
        keys=params.SAGEMAKER_HYPERPARAMETERS,
        prefix=params.SAGEMAKER_PREFIX,
    )
    sagemaker_hyperparameters = split_result.included
    # SageMaker-side entries that are not among the known
    # SAGEMAKER_HYPERPARAMETERS (presumably those selected only by the prefix).
    additional_framework_parameters = {
        k: sagemaker_hyperparameters[k]
        for k in sagemaker_hyperparameters.keys()
        if k not in params.SAGEMAKER_HYPERPARAMETERS
    }
    # NOTE(review): the boto3 fallback is evaluated eagerly, so a Session is
    # created even when the region hyperparameter is present.
    sagemaker_region = sagemaker_hyperparameters.get(
        params.REGION_NAME_PARAM, boto3.session.Session().region_name
    )
    # Publish job name, current host and region to the process environment.
    # `or ""` is needed because os.environ only accepts str values and
    # region_name may be None.
    os.environ[params.JOB_NAME_ENV] = sagemaker_hyperparameters.get(params.JOB_NAME_PARAM, "")
    os.environ[params.CURRENT_HOST_ENV] = current_host
    os.environ[params.REGION_NAME_ENV] = sagemaker_region or ""
    self._hosts = hosts
    # Use the interface name from the resource config when provided; otherwise
    # default to eth0, the interface SageMaker defines with VPC support and in
    # local mode.
    # NOTE(review): an earlier comment said SageMaker training used "ethwe" and
    # would switch to eth0; confirm whether "ethwe" is still relevant anywhere.
    self._network_interface_name = resource_config.get("network_interface_name", "eth0")
    self._hyperparameters = split_result.excluded
    self._additional_framework_parameters = additional_framework_parameters
    self._resource_config = resource_config
    self._input_data_config = input_data_config
    self._output_data_dir = output_data_dir
    self._output_intermediate_dir = output_intermediate_dir
    # One entry per input channel, mapping channel name -> channel_path(channel).
    self._channel_input_dirs = {channel: channel_path(channel) for channel in input_data_config}
    # Replaces the environment-derived value assigned at the top of __init__.
    self._current_host = current_host
    # override base class attributes: replace the environment-derived values
    # set above with SageMaker hyperparameters where applicable.
    if self._module_name is None:
        # NOTE(review): when USER_PROGRAM_PARAM is absent this stores the
        # string "None" (str(None)), not None -- confirm callers expect that.
        self._module_name = str(sagemaker_hyperparameters.get(params.USER_PROGRAM_PARAM, None))
    # Keep a truthy environment value; otherwise use the hyperparameter (may be None).
    self._user_entry_point = self._user_entry_point or sagemaker_hyperparameters.get(
        params.USER_PROGRAM_PARAM
    )
    # NOTE(review): unconditionally replaces the SUBMIT_DIR_ENV value read
    # above; the fallback is code_dir, not the environment value.
    self._module_dir = str(sagemaker_hyperparameters.get(params.SUBMIT_DIR_PARAM, code_dir))
    # NOTE(review): unconditionally replaces the LOG_LEVEL_ENV value read above
    # (falls back to logging.INFO) and, unlike that value, is not int()-converted.
    self._log_level = sagemaker_hyperparameters.get(params.LOG_LEVEL_PARAM, logging.INFO)
    self._sagemaker_s3_output = sagemaker_hyperparameters.get(
        params.S3_OUTPUT_LOCATION_PARAM, None
    )
    self._framework_module = os.environ.get(params.FRAMEWORK_TRAINING_MODULE_ENV, None)
    self._input_dir = input_dir
    self._input_config_dir = input_config_dir
    self._output_dir = output_dir
    # The environment variable name is the upper-cased TRAINING_JOB_ENV constant.
    self._job_name = os.environ.get(params.TRAINING_JOB_ENV.upper(), None)
    # The first host listed is treated as the master; raises IndexError if
    # hosts is empty.
    self._master_hostname = list(hosts)[0]
    self._is_master = current_host == self._master_hostname