in src/sagemaker/estimator.py [0:0]
def _prepare_for_training(self, job_name=None):
    """Populate the estimator state that has to exist before training starts.

    In order, this resolves the job name, fills in ``output_path`` when the
    caller left it unset, clones the configured Git repository (if any),
    mounts or stages the user code and points the hyperparameters at it, and
    finally runs the rule, debugger and profiler preparation hooks.

    Args:
        job_name (str): Name of the training job to be created. If not
            specified, one is generated, using the base name given to the
            constructor if applicable.
    """
    self._current_job_name = self._get_or_create_name(job_name)

    # An explicitly supplied output_path always wins; only fill in a default
    # when the caller left it unset.
    if self.output_path is None:
        session = self.sagemaker_session
        mount_local_code = get_config_value("local.local_code", session.config)
        if session.local_mode and mount_local_code:
            # Local Mode with local_code=True needs no explicit output_path.
            self.output_path = ""
        else:
            bucket = session.default_bucket()
            prefix = session.default_bucket_prefix
            self.output_path = s3.s3_path_join("s3://", bucket, prefix, with_end_slash=True)
            # Record that the path was derived from the session defaults.
            self._is_output_path_set_from_default_bucket_and_prefix = True

    if self.git_config:
        cloned = git_utils.git_clone_repo(
            self.git_config, self.entry_point, self.source_dir, self.dependencies
        )
        # Replace the code locations with the ones returned by the clone.
        for attr in ("entry_point", "source_dir", "dependencies"):
            setattr(self, attr, cloned[attr])

    if self.source_dir or self.entry_point or self.dependencies:
        # Only a concrete, non-S3 source_dir is validated. validate_source_dir
        # raises a ValueError when something is wrong with the directory; that
        # is deliberately not handled here because it is a critical error.
        source_dir = self.source_dir
        if source_dir and not (
            is_pipeline_variable(source_dir) or source_dir.lower().startswith("s3://")
        ):
            validate_source_dir(self.entry_point, source_dir)

        session = self.sagemaker_session
        mount_local_code = get_config_value("local.local_code", session.config)

        if session.local_mode and mount_local_code:
            # Local mode with local_code=True: the container should just mount
            # the source dir instead of having it uploaded to S3.
            entry = self.entry_point
            if self.source_dir is None:
                # No source dir given: use the directory containing the entry point.
                self.source_dir = os.path.dirname(entry)
            self.entry_point = script = os.path.basename(entry)
            code_dir = "file://" + self.source_dir
        else:
            # Decide on network isolation before staging, then stage once.
            run_isolated = self.enable_network_isolation() and self.entry_point
            staged = self._stage_user_code_in_s3()
            self.uploaded_code = staged
            if run_isolated:
                code_dir = self.CONTAINER_CODE_CHANNEL_SOURCEDIR_PATH
                script = staged.script_name
                self.code_uri = staged.s3_prefix
            else:
                code_dir = staged.s3_prefix
                script = staged.script_name

        # Update the hyperparameters in place so they reference the resolved
        # code directory and script.
        self._script_mode_hyperparam_update(code_dir, script)

    self._prepare_rules()
    self._prepare_debugger_for_training()
    self._prepare_profiler_for_training()