# Excerpt: _setup_for_training_recipe() from src/sagemaker/pytorch/estimator.py


    def _setup_for_training_recipe(cls, training_recipe, recipe_overrides, source_dir, kwargs):
        """Performs training recipe specific setup and returns recipe specific args.

        Updates kwargs and returns a dictionary of args to use for estimator
        initialization and setup when using a training recipe. Updates the paths in
        the recipe for Sagemaker Jobs environment.

        Args:
            training_recipe (str): A recipe which is a local file path, a url or a
                                   sagemaker training recipe.
            recipe_overrides (Dict): Dictionary specifying key values to override in the
                                     training recipe.
            source_dir (str): Path (absolute, or relative) to a directory where to copy
                              the scripts for training recipe. requirements.txt can also
                              go here.
            kwargs (dict): Dictionary of args used for estimator initialization.

        Returns:
            dict containing arg values for estimator initialization and setup.

        Raises:
            ValueError: If the recipe cannot be fetched or found, required kwargs
                (instance_type, and instance_count when the recipe lacks num_nodes)
                are missing or malformed, or the device type is unsupported.
            RuntimeError: If the merged recipe cannot be resolved.
        """
        # Resolve the AWS region: prefer the caller-supplied session, else a fresh one.
        if kwargs.get("sagemaker_session") is not None:
            region_name = kwargs.get("sagemaker_session").boto_region_name
        else:
            region_name = Session().boto_region_name

        # Packaged config mapping device types to launcher/adapter repos and images.
        training_recipes_cfg_filename = os.path.join(
            os.path.dirname(__file__), "training_recipes.json"
        )
        with open(training_recipes_cfg_filename) as training_recipes_cfg_file:
            training_recipes_cfg = json.load(training_recipes_cfg_file)

        if recipe_overrides is None:
            recipe_overrides = dict()
        recipe_train_dir = tempfile.TemporaryDirectory(prefix="training_")
        recipe_launcher_dir = tempfile.TemporaryDirectory(prefix="launcher_")
        args = dict()
        if source_dir is None:
            args["source_dir"] = "."
        else:
            if not os.path.exists(source_dir):
                raise ValueError(
                    "When using training_recipe, source_dir must be a local directory."
                )
            args["source_dir"] = source_dir

        recipe_name = os.path.splitext(os.path.basename(training_recipe))[0]
        # Use mkstemp instead of NamedTemporaryFile(...).name: the latter discards
        # the file object, so the file is deleted as soon as the object is garbage
        # collected, leaving a predictable-name race before the path is reused.
        # mkstemp keeps the file on disk until the explicit os.unlink below.
        temp_recipe_fd, temp_local_recipe = tempfile.mkstemp(
            prefix=recipe_name, suffix=".yaml"
        )
        os.close(temp_recipe_fd)
        if training_recipe.endswith(".yaml"):
            # A .yaml argument is either a local file or a URL to fetch.
            if os.path.isfile(training_recipe):
                shutil.copy(training_recipe, temp_local_recipe)
            else:
                try:
                    urlretrieve(training_recipe, temp_local_recipe)
                except Exception as e:
                    raise ValueError(
                        f"Could not fetch the provided recipe {training_recipe}: exception {str(e)}"
                    ) from e
        else:
            # Otherwise treat it as a named sagemaker recipe: clone the launcher
            # repo and look the recipe up inside its recipes collection.
            launcher_repo = os.environ.get(
                "TRAINING_LAUNCHER_GIT", None
            ) or training_recipes_cfg.get("launcher_repo")
            _run_clone_command(launcher_repo, recipe_launcher_dir.name)
            recipe = os.path.join(
                recipe_launcher_dir.name,
                "recipes_collection",
                "recipes",
                training_recipe + ".yaml",
            )
            if os.path.isfile(recipe):
                shutil.copy(recipe, temp_local_recipe)
            else:
                raise ValueError(f"Recipe {training_recipe} not found.")

        recipe = OmegaConf.load(temp_local_recipe)
        os.unlink(temp_local_recipe)
        recipe = OmegaConf.merge(recipe, recipe_overrides)

        if "instance_type" not in kwargs:
            raise ValueError("Must pass instance type to estimator when using training recipes.")
        # e.g. "ml.p4d.24xlarge" -> family "p4d". Guard against values with no
        # "." so we raise a clear error instead of an opaque IndexError.
        instance_type_parts = kwargs["instance_type"].split(".")
        if len(instance_type_parts) < 2:
            raise ValueError(
                f"Could not parse instance family from instance_type {kwargs['instance_type']}."
            )
        instance_type = instance_type_parts[1]
        if instance_type.startswith(("p", "g")):
            device_type = "gpu"
        elif instance_type.startswith("trn"):
            device_type = "trainium"
        else:
            device_type = "cpu"

        if "trainer" not in recipe:
            raise ValueError("Supplied recipe does not contain required field trainer.")
        # instance_count from the estimator wins over the recipe's num_nodes.
        if "instance_count" in kwargs and "num_nodes" in recipe["trainer"]:
            logger.warning(
                "Using instance_count argument to estimator to set number "
                "of nodes. Ignoring trainer -> num_nodes in recipe."
            )
        if "instance_count" not in kwargs:
            if "num_nodes" not in recipe["trainer"]:
                raise ValueError(
                    "Must set either instance_count argument for estimator or "
                    "set trainer -> num_nodes in recipe."
                )
            kwargs["instance_count"] = recipe["trainer"]["num_nodes"]

        # [TODO] Add image uris to image_uri_config/_.json and use image_uris.retrieve
        # to retrieve the image uri below before we go GA.
        if device_type == "gpu":
            # GPU path: clone the adapter repo for the training script and run
            # with SMP model parallelism + torch distributed.
            adapter_repo = os.environ.get("TRAINING_ADAPTER_GIT", None) or training_recipes_cfg.get(
                "adapter_repo"
            )
            _run_clone_command(adapter_repo, recipe_train_dir.name)
            script = _get_training_recipe_gpu_script(
                recipe_train_dir.name, recipe, args["source_dir"]
            )
            args["default_image_uri"] = _get_training_recipe_image_uri(
                training_recipes_cfg.get("gpu_image"), region_name
            )
            smp_options = {
                "enabled": True,
                "parameters": {
                    "placement_strategy": "cluster",
                },
            }
            args["distribution"] = {
                "smdistributed": {"modelparallel": smp_options},
                "torch_distributed": {"enabled": True},
            }
        elif device_type == "trainium":
            # Trainium path: clone the neuron distributed repo; torch distributed only.
            _run_clone_command(training_recipes_cfg.get("neuron_dist_repo"), recipe_train_dir.name)
            script = _get_training_recipe_trainium_script(recipe_train_dir.name, args["source_dir"])
            args["default_image_uri"] = _get_training_recipe_image_uri(
                training_recipes_cfg.get("neuron_image"), region_name
            )
            args["distribution"] = {
                "torch_distributed": {"enabled": True},
            }
        else:
            raise ValueError(
                f"Devices of type {device_type} are not supported with training recipes."
            )
        args["entry_point"] = os.path.basename(script)

        recipe_train_dir.cleanup()
        recipe_launcher_dir.cleanup()

        # NOTE(review): this warns only when "container" is present but falsy;
        # a truthy container value passes silently — confirm the intent was not
        # `recipe["container"]` (warn when a container IS specified).
        if "container" in recipe and not recipe["container"]:
            logger.warning(
                "Ignoring container from training_recipe. Use image_uri arg for estimator."
            )

        # Resolve OmegaConf interpolations; fall back to nested resolution roots.
        _setup_omegaconf_resolvers()
        final_recipe = _try_resolve_recipe(recipe)
        if final_recipe is None:
            final_recipe = _try_resolve_recipe(recipe, "recipes")
        if final_recipe is None:
            final_recipe = _try_resolve_recipe(recipe, "training")
        if final_recipe is None:
            raise RuntimeError("Could not resolve provided recipe.")
        # Keep a reference on cls so the temp file outlives this call; it is
        # uploaded with source_dir and named via the hyperparameters below.
        cls.training_recipe_file = tempfile.NamedTemporaryFile(
            dir=args["source_dir"],
            prefix=recipe_name + "_",
            suffix=".yaml",
        )
        OmegaConf.save(config=final_recipe, f=cls.training_recipe_file.name)
        args["hyperparameters"] = {
            "config-path": ".",
            "config-name": os.path.basename(cls.training_recipe_file.name),
        }

        return args