in src/sagemaker/estimator.py [0:0]
def _prepare_for_training(self, job_name=None):
    """Resolve where the user code lives and record it in the hyperparameters.

    After delegating to the parent implementation, this method optionally
    clones the configured git repository, validates ``source_dir`` when it is
    not an S3 URI, and then decides how the container gets the code: a
    ``file://`` reference to the source directory (local mode with
    ``local.local_code`` enabled) or code staged in S3.

    Args:
        job_name (str): Name of the training job to be created. If not
            specified, one is generated, using the base name given to the
            constructor if applicable.

    Raises:
        ValueError: Propagated from ``validate_source_dir`` when something is
            wrong with the source directory.
    """
    super(Framework, self)._prepare_for_training(job_name=job_name)

    if self.git_config:
        # Cloning may relocate all three paths, so take them from the result.
        cloned = git_utils.git_clone_repo(
            self.git_config, self.entry_point, self.source_dir, self.dependencies
        )
        self.entry_point = cloned["entry_point"]
        self.source_dir = cloned["source_dir"]
        self.dependencies = cloned["dependencies"]

    # A bad source directory is a critical error, so the ValueError raised by
    # validate_source_dir is intentionally left unhandled. S3 URIs are skipped.
    needs_validation = bool(self.source_dir) and not self.source_dir.lower().startswith(
        "s3://"
    )
    if needs_validation:
        validate_source_dir(self.entry_point, self.source_dir)

    session = self.sagemaker_session
    use_local_code = get_config_value("local.local_code", session.config)

    if session.local_mode and use_local_code:
        # Local mode with local_code=True: the container should just mount the
        # source dir instead of having it uploaded to S3.
        if self.source_dir is None:
            # No source dir given: use the directory containing the entry point.
            self.source_dir = os.path.dirname(self.entry_point)
        self.entry_point = os.path.basename(self.entry_point)

        code_dir = "file://" + self.source_dir
        script = self.entry_point
    else:
        # Evaluate the isolation check before staging, as both remaining cases
        # stage the code in S3 and differ only in what is reported afterwards.
        isolated_with_entry_point = self.enable_network_isolation() and self.entry_point

        self.uploaded_code = self._stage_user_code_in_s3()
        script = self.uploaded_code.script_name

        if isolated_with_entry_point:
            code_dir = self.CONTAINER_CODE_CHANNEL_SOURCEDIR_PATH
            self.code_uri = self.uploaded_code.s3_prefix
        else:
            code_dir = self.uploaded_code.s3_prefix

    # Update the hyperparameters in place so they point at the right code
    # directory and script, and carry the job-level settings.
    hyperparameters = self._hyperparameters
    hyperparameters[DIR_PARAM_NAME] = code_dir
    hyperparameters[SCRIPT_PARAM_NAME] = script
    hyperparameters[CONTAINER_LOG_LEVEL_PARAM_NAME] = self.container_log_level
    hyperparameters[JOB_NAME_PARAM_NAME] = self._current_job_name
    hyperparameters[SAGEMAKER_REGION_PARAM_NAME] = self.sagemaker_session.boto_region_name

    self._validate_and_set_debugger_configs()