def main()

in template/sm_jobs.py [0:0]


def main():
    args = parse_args()

    sagemaker_session = sagemaker.Session()
    role = sagemaker.get_execution_role()

    sm_jobs_config = OmegaConf.load(args.sm_jobs_config)
    recipe_overrides = sm_jobs_config.get("recipe_overrides", omegaconf.DictConfig(dict()))
    recipe = OmegaConf.load(args.recipe)
    recipe = OmegaConf.merge(recipe, recipe_overrides)
    recipe_overrides = OmegaConf.to_container(recipe_overrides)

    sm_inputs = sm_jobs_config.get("inputs")
    inputs = None
    if sm_inputs:
        s3 = sm_inputs.get("s3")
        file_system = sm_inputs.get("file_system")
        if s3 and file_system:
            raise ValueError("Must set only one of s3 or file_system in sm_jobs_config.inputs.")
        if s3 is None and file_system is None:
            raise ValueError("Must set either s3 or file_system in sm_jobs_config.inputs.")
        if s3:
            inputs = OmegaConf.to_container(s3)
        else:
            file_system_id = file_system.get("id")
            file_system_type = file_system.get("type")
            directory_path = file_system.get("directory_path")
            if file_system_id is None or file_system_type is None or directory_path is None:
                raise ValueError("Must set id, type and directory_path for file_system input type in sm_jobs_config.")
            inputs = FileSystemInput(
                file_system_id=file_system_id,
                file_system_type=file_system_type,
                directory_path=directory_path,
                file_system_access_mode="ro",
            )

    output_path = sm_jobs_config.get("output_path")
    if output_path is None:
        raise ValueError("Expected output_path to be set with sm_jobs cluster type")

    additional_estimator_kwargs = sm_jobs_config.get("additional_estimator_kwargs", omegaconf.DictConfig(dict()))
    additional_estimator_kwargs = OmegaConf.to_container(additional_estimator_kwargs)

    tensorboard_config = sm_jobs_config.get("tensorboard_config")
    if tensorboard_config:
        tb_output_path = tensorboard_config.get("output_path")
        tb_container_path = tensorboard_config.get("container_logs_path")
        if tb_output_path is None or tb_container_path is None:
            raise ValueError("Please set output path and container path when using tensorboard.")
        tensorboard_output_config = TensorBoardOutputConfig(
            s3_output_path=tb_output_path, container_local_output_path=tb_container_path
        )
        additional_estimator_kwargs["tensorboard_output_config"] = tensorboard_output_config
        if recipe.get("exp_manager") is None or recipe.get("exp_manager", dict()).get("explicit_log_dir") is None:
            logger.warning("Using tensorboard but not set exp_manager -> explicit_log_dir for recipe.")

    base_job_name = args.job_name.replace(".", "-")
    base_job_name = base_job_name.replace("_", "-")
    estimator = PyTorch(
        base_job_name=base_job_name,
        instance_type=args.instance_type,
        training_recipe=args.recipe,
        recipe_overrides=recipe_overrides,
        output_path=output_path,
        role=role,
        sagemaker_session=sagemaker_session,
        **additional_estimator_kwargs,
    )

    if tensorboard_config:
        logger.info("Tensorboard url:")
        logger.info(
            estimator.get_app_url(
                app_type=SupportedInteractiveAppTypes.TENSORBOARD,
                open_in_default_web_browser=False,
            )
        )

    if not isinstance(inputs, FileSystemInput):
        keys_to_pop = []
        for item in inputs.keys():
            if not inputs[item]:
                print(f"poping input {inputs[item]}, {item}")
                keys_to_pop.append(item)
        for item in keys_to_pop:
            inputs.pop(item)
        if len(inputs) == 0:
            inputs = None

    estimator.fit(inputs=inputs, wait=sm_jobs_config.get("wait", False))