in src/sagemaker/modules/train/model_trainer.py [0:0]
def model_post_init(self, __context: Any):
    """Post init method to perform custom validation and set default values.

    Runs the training image/algorithm, source code and distributed config
    validators, then fills in a default for each of ``sagemaker_session``,
    ``role``, ``base_job_name``, ``compute``, ``stopping_condition`` and
    ``output_data_config`` that is still ``None``. The session, role and
    output data config defaults are only applied when ``training_mode`` is
    ``Mode.SAGEMAKER_TRAINING_JOB``.

    When ``hyperparameters`` is a non-empty string it is treated as the path
    of a JSON or YAML file and is replaced by the mapping parsed from it.

    Args:
        __context (Any): Not used by this method.

    Raises:
        ValueError: If the hyperparameters file does not exist, cannot be
            parsed as JSON or YAML, or does not contain a mapping.
    """
    self._validate_training_image_and_algorithm_name(self.training_image, self.algorithm_name)
    self._validate_source_code(self.source_code)
    self._validate_distributed_config(self.source_code, self.distributed)

    if self.training_mode == Mode.SAGEMAKER_TRAINING_JOB:
        if self.sagemaker_session is None:
            self.sagemaker_session = Session()
            logger.warning("SageMaker session not provided. Using default Session.")

        if self.role is None:
            self.role = get_execution_role(sagemaker_session=self.sagemaker_session)
            logger.warning(f"Role not provided. Using default role:\n{self.role}")

    if self.base_job_name is None:
        # Prefer the algorithm name; otherwise derive the name from the image repo.
        if self.algorithm_name:
            self.base_job_name = f"{self.algorithm_name}-job"
        elif self.training_image:
            self.base_job_name = f"{_get_repo_name_from_image(self.training_image)}-job"
        logger.warning(f"Base name not provided. Using default name:\n{self.base_job_name}")

    if self.compute is None:
        self.compute = Compute(
            instance_type=DEFAULT_INSTANCE_TYPE,
            instance_count=1,
            volume_size_in_gb=30,
        )
        logger.warning(f"Compute not provided. Using default:\n{self.compute}")

    if self.stopping_condition is None:
        self.stopping_condition = StoppingCondition(
            max_runtime_in_seconds=3600,
            max_pending_time_in_seconds=None,
            max_wait_time_in_seconds=None,
        )
        logger.warning(
            f"StoppingCondition not provided. Using default:\n{self.stopping_condition}"
        )

    if self.hyperparameters and isinstance(self.hyperparameters, str):
        # Keep the path in a local: self.hyperparameters is reassigned below, and
        # the error message must still name the file rather than its parsed contents.
        hyperparameters_path = self.hyperparameters
        if not os.path.exists(hyperparameters_path):
            raise ValueError(f"Hyperparameters file not found: {hyperparameters_path}")
        logger.info(f"Loading hyperparameters from file: {hyperparameters_path}")

        # Read first, parse afterwards, so the file handle is not held while parsing.
        with open(hyperparameters_path, "r", encoding="utf-8") as f:
            contents = f.read()

        try:
            try:
                loaded = json.loads(contents)
                file_format = "JSON"
            except json.JSONDecodeError:
                # Not JSON: fall back to YAML. File contents are logged at debug
                # level only, so they do not end up in default (INFO) logs.
                logger.debug(f"contents: {contents}")
                loaded = yaml.safe_load(contents)
                file_format = "YAML"
            # Apply the mapping requirement to both formats, not only YAML.
            if not isinstance(loaded, dict):
                raise ValueError("Hyperparameters file contents must be a valid mapping")
            self.hyperparameters = loaded
        except (yaml.YAMLError, ValueError) as e:
            raise ValueError(
                f"Invalid hyperparameters file: {hyperparameters_path}. "
                "Must be a valid JSON or YAML file."
            ) from e
        logger.debug(f"hyperparameters: {self.hyperparameters}")
        logger.debug(f"Hyperparameters loaded as {file_format}")

    if self.training_mode == Mode.SAGEMAKER_TRAINING_JOB and self.output_data_config is None:
        session = self.sagemaker_session
        base_job_name = self.base_job_name
        self.output_data_config = OutputDataConfig(
            s3_output_path=f"s3://{self._fetch_bucket_name_and_prefix(session)}"
            f"/{base_job_name}",
            compression_type="GZIP",
            kms_key_id=None,
        )
        logger.warning(
            f"OutputDataConfig not provided. Using default:\n{self.output_data_config}"
        )

    # TODO: Autodetect which image to use if source_code is provided
    if self.training_image:
        logger.info(f"Training image URI: {self.training_image}")