src/sagemaker/pytorch/estimator.py

# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Placeholder docstring"""
from __future__ import absolute_import

import json
import logging
import math
import os
import shutil
import tempfile
from typing import Union, Optional, Dict
from urllib.request import urlretrieve

import omegaconf
from omegaconf import OmegaConf, dictconfig
from packaging.version import Version

from sagemaker.estimator import Framework, EstimatorBase
from sagemaker.fw_utils import (
    framework_name_from_image,
    framework_version_from_tag,
    python_deprecation_warning,
    validate_version_or_image_args,
    validate_distribution,
    profiler_config_deprecation_warning,
)
from sagemaker.git_utils import _run_clone_command
from sagemaker.image_uris import retrieve
from sagemaker.pytorch import defaults
from sagemaker.pytorch.model import PyTorchModel
from sagemaker.pytorch.training_compiler.config import TrainingCompilerConfig
from sagemaker.session import Session
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
from sagemaker.workflow.entities import PipelineVariable

logger = logging.getLogger("sagemaker")


def _setup_omegaconf_resolvers():
    """Set up omegaconf resolvers for training recipes."""
    if not OmegaConf.has_resolver("multiply"):
        OmegaConf.register_new_resolver("multiply", lambda x, y: x * y, replace=True)
    if not OmegaConf.has_resolver("divide_ceil"):
        OmegaConf.register_new_resolver(
            "divide_ceil", lambda x, y: int(math.ceil(x / y)), replace=True
        )
    if not OmegaConf.has_resolver("divide_floor"):
        OmegaConf.register_new_resolver(
            "divide_floor", lambda x, y: int(math.floor(x / y)), replace=True
        )
    if not OmegaConf.has_resolver("add"):
        OmegaConf.register_new_resolver("add", lambda *numbers: sum(numbers))


def _try_resolve_recipe(recipe, key=None):
    """Try to resolve recipe and return resolved recipe."""
    if key is not None:
        recipe = dictconfig.DictConfig({key: recipe})
    try:
        OmegaConf.resolve(recipe)
    except omegaconf.errors.OmegaConfBaseException:
        return None
    if key is None:
        return recipe
    return recipe[key]


def _get_training_recipe_image_uri(image_cfg, region_name):
    """Fetch image uri given image spec and region name to use for training."""
    if isinstance(image_cfg, str):
        return image_cfg
    return retrieve(
        image_cfg.get("framework"),
        region=region_name,
        version=image_cfg.get("version"),
        image_scope="training",
        **image_cfg.get("additional_args"),
    )


def _get_training_recipe_gpu_script(code_dir, recipe, source_dir):
    """Return path to training script (entry point) when running a gpu recipe."""
    model_type_to_script = {
        "llama_v3": ("llama", "llama_pretrain.py"),
        "mistral": ("mistral", "mistral_pretrain.py"),
        "mixtral": ("mixtral", "mixtral_pretrain.py"),
        "deepseek": ("deepseek", "deepseek_pretrain.py"),
    }

    if "model" not in recipe:
        raise ValueError("Supplied recipe does not contain required field model.")
    if "model_type" not in recipe["model"]:
        raise ValueError("Supplied recipe does not contain required field model_type.")
    model_type = recipe["model"]["model_type"]
    for key in model_type_to_script:
        if model_type.startswith(key):
            model_type = key
            break
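    # Note: matching is by prefix, so a hypothetical recipe model_type such as
    # "llama_v3_70b" resolves to the "llama_v3" entry above and uses
    # examples/llama/llama_pretrain.py as its training script.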

    if model_type not in model_type_to_script:
        raise ValueError(f"Model type {model_type} not supported")

    script_dir = os.path.join(code_dir, "examples", model_type_to_script[model_type][0])
    script = model_type_to_script[model_type][1]
    shutil.copyfile(os.path.join(script_dir, script), os.path.join(source_dir, script))
    return script


def _get_training_recipe_trainium_script(code_dir, source_dir):
    """Return path to training script (entry point) when running a trainium recipe."""
    script_dir = os.path.join(code_dir, "examples")
    script = "training_orchestrator.py"
    shutil.copytree(script_dir, source_dir, dirs_exist_ok=True)
    return script


class PyTorch(Framework):
    """Handle end-to-end training and deployment of custom PyTorch code."""

    _framework_name = "pytorch"
    LAUNCH_PYTORCH_DDP_ENV_NAME = "sagemaker_pytorch_ddp_enabled"
    LAUNCH_TORCH_DISTRIBUTED_ENV_NAME = "sagemaker_torch_distributed_enabled"
    INSTANCE_TYPE_ENV_NAME = "sagemaker_instance_type"

    # [TODO] Add image uris to image_uri_config/_.json and use image_uris.retrieve
    # to retrieve the image uri below before GA.
    def __init__(
        self,
        entry_point: Optional[Union[str, PipelineVariable]] = None,
        framework_version: Optional[str] = None,
        py_version: Optional[str] = None,
        source_dir: Optional[Union[str, PipelineVariable]] = None,
        hyperparameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
        image_uri: Optional[Union[str, PipelineVariable]] = None,
        distribution: Optional[Dict] = None,
        compiler_config: Optional[TrainingCompilerConfig] = None,
        training_recipe: Optional[str] = None,
        recipe_overrides: Optional[Dict] = None,
        **kwargs,
    ):
        """This ``Estimator`` executes a PyTorch script in a managed PyTorch execution environment.

        The managed PyTorch environment is an Amazon-built Docker container that executes
        functions defined in the supplied ``entry_point`` Python script within a SageMaker
        Training Job.

        Training is started by calling
        :meth:`~sagemaker.amazon.estimator.Framework.fit` on this Estimator.
        After training is complete, calling
        :meth:`~sagemaker.amazon.estimator.Framework.deploy` creates a hosted SageMaker
        endpoint and returns an
        :class:`~sagemaker.amazon.pytorch.model.PyTorchPredictor` instance that can be used
        to perform inference against the hosted model.

        Technical documentation on preparing PyTorch scripts for SageMaker training and
        using the PyTorch Estimator is available on the project home-page:
        https://github.com/aws/sagemaker-python-sdk

        Args:
            entry_point (str or PipelineVariable): Path (absolute or relative) to the Python
                source file which should be executed as the entry point to training.
                If ``source_dir`` is specified, then ``entry_point`` must point to a file
                located at the root of ``source_dir``.
            framework_version (str): PyTorch version you want to use for executing your model
                training code. Defaults to ``None``. Required unless ``image_uri`` is provided.
                List of supported versions:
                https://github.com/aws/deep-learning-containers/blob/master/available_images.md.
            py_version (str): Python version you want to use for executing your model training
                code. One of 'py2' or 'py3'. Defaults to ``None``. Required unless
                ``image_uri`` is provided.
            source_dir (str or PipelineVariable): Path (absolute, relative or an S3 URI) to
                a directory with any other training source code dependencies aside from the
                entry point file (default: None). If ``source_dir`` is an S3 URI, it must
                point to a file with name ``sourcedir.tar.gz``. Structure within this
                directory is preserved when training on Amazon SageMaker.
                Must be a local path when using training_recipe.
            hyperparameters (dict[str, str] or dict[str, PipelineVariable]): Hyperparameters
                that will be used for training (default: None). The hyperparameters are made
                accessible as a dict[str, str] to the training code on SageMaker. For
                convenience, this accepts other types for keys and values, but ``str()`` will
                be called to convert them before training.
            image_uri (str or PipelineVariable): If specified, the estimator will use this image
                for training and hosting, instead of selecting the appropriate SageMaker
                official image based on framework_version and py_version. It can be an ECR url
                or dockerhub image and tag.

                Examples:
                    * ``123412341234.dkr.ecr.us-west-2.amazonaws.com/my-custom-image:1.0``
                    * ``custom-image:latest``

                If ``framework_version`` or ``py_version`` are ``None``, then ``image_uri``
                is required. If ``image_uri`` is also ``None``, a ``ValueError`` will be raised.
            distribution (dict): A dictionary with information on how to configure and run
                distributed training (default: None). The following options are available.

                **To enable the SageMaker distributed data parallelism (SMDDP) library:**

                    .. code:: python

                        {
                            "smdistributed": {
                                "dataparallel": {
                                    "enabled": True
                                }
                            }
                        }

                    Besides activating the SMDDP library through this parameter, you also need
                    to add a few lines of code in your training script for initializing PyTorch
                    Distributed with the SMDDP setups. To learn how to configure your training
                    job with the SMDDP library v2, see `Run distributed training with the
                    SageMaker distributed data parallelism library
                    <https://docs.aws.amazon.com/sagemaker/latest/dg/data-parallel.html>`_
                    in the *Amazon SageMaker User Guide*.

                **To enable the SageMaker distributed model parallelism (SMP) library v2:**

                    .. code:: python

                        {
                            "torch_distributed": { "enabled": True },
                            "smdistributed": {
                                "modelparallel": {
                                    "enabled": True,
                                    "parameters": {
                                        "tensor_parallel_degree": 8,
                                        "hybrid_shard_degree": 1,
                                        ...
                                    },
                                }
                            },
                        }

                    Besides activating the SMP library v2 through this parameter, you also need
                    to add a few lines of code in your training script for initializing PyTorch
                    Distributed with the SMP setups. To learn how to configure your training job
                    with the SMP library v2, see `Run distributed training with the SageMaker
                    model parallelism library v2
                    <https://docs.aws.amazon.com/sagemaker/latest/dg/model-parallel-v2.html>`_
                    in the *Amazon SageMaker User Guide*.

                    .. note::

                        The SageMaker distributed model parallel library v2 requires
                        ``torch_distributed``.

                    .. note::

                        The documentation for the SMP library v1.x is archived and available at
                        `Run distributed training with the SageMaker model parallelism library
                        <https://docs.aws.amazon.com/sagemaker/latest/dg/model-parallel.html>`_
                        in the *Amazon SageMaker User Guide*, and the SMP v1 API reference is
                        available in the `SageMaker Python SDK v2.199.0 documentation
                        <https://sagemaker.readthedocs.io/en/v2.199.0/api/training/distributed.html#the-sagemaker-distributed-model-parallel-library>`_.

                **To enable PyTorch DDP:**

                    .. code:: python

                        {
                            "pytorchddp": {
                                "enabled": True
                            }
                        }

                    To learn more, see `Distributed PyTorch Training
                    <https://sagemaker.readthedocs.io/en/stable/frameworks/pytorch/using_pytorch.html#distributed-pytorch-training>`_.

                **To enable Torch Distributed:**

                    This is available for general distributed training on
                    GPU instances from PyTorch v1.13.1 and later.

                    .. code:: python

                        {
                            "torch_distributed": {
                                "enabled": True
                            }
                        }

                    This option also supports distributed training on Trn1.
                    To learn more, see `Distributed PyTorch Training on Trainium
                    <https://sagemaker.readthedocs.io/en/stable/frameworks/pytorch/using_pytorch.html#distributed-pytorch-training-on-trainium>`_.

                **To enable MPI:**

                    .. code:: python

                        {
                            "mpi": {
                                "enabled": True
                            }
                        }

                    To learn more, see `Training with Horovod
                    <https://sagemaker.readthedocs.io/en/stable/frameworks/tensorflow/using_tf.html#training-with-horovod>`_.

                **To enable parameter server:**

                    .. code:: python

                        {
                            "parameter_server": {
                                "enabled": True
                            }
                        }

                    To learn more, see `Training with parameter servers
                    <https://sagemaker.readthedocs.io/en/stable/frameworks/tensorflow/using_tf.html#training-with-parameter-servers>`_.

                **To enable distributed training with SageMaker Training Compiler:**

                    .. code:: python

                        {
                            "pytorchxla": {
                                "enabled": True
                            }
                        }

                    To learn more, see `SageMaker Training Compiler
                    <https://docs.aws.amazon.com/sagemaker/latest/dg/training-compiler.html>`_
                    in the *Amazon SageMaker Developer Guide*.

                    .. note::

                        When you use this PyTorch XLA option as the distributed training
                        strategy, you must add the ``compiler_config`` parameter and activate
                        SageMaker Training Compiler.

            compiler_config (:class:`~sagemaker.pytorch.TrainingCompilerConfig`):
                Configures SageMaker Training Compiler to accelerate training.
            training_recipe (str): Training recipe to use. This is a local file path, a URL, or
                a recipe provided by Amazon SageMaker HyperPod recipes, such as
                training/llama/hf_llama3_70b_seq8k_gpu_p5x64_pretrain. This is required when
                using recipes.
            recipe_overrides (Dict): Dictionary specifying key values to override in the
                training_recipe. This is optional when using Amazon SageMaker HyperPod recipes.
            **kwargs: Additional kwargs passed to the :class:`~sagemaker.estimator.Framework`
                constructor.

        .. tip::

            You can find additional parameters for initializing this class at
            :class:`~sagemaker.estimator.Framework` and
            :class:`~sagemaker.estimator.EstimatorBase`.
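
        **Example (illustrative sketch):**

            The snippet below shows one way to construct this estimator for a
            ``torch_distributed`` job. The entry point name, role ARN, framework and
            Python versions, S3 path, and instance settings are placeholder values,
            not recommendations.

            .. code:: python

                from sagemaker.pytorch import PyTorch

                estimator = PyTorch(
                    entry_point="train.py",
                    role="arn:aws:iam::111122223333:role/SageMakerRole",
                    framework_version="2.3",
                    py_version="py311",
                    instance_count=2,
                    instance_type="ml.p4d.24xlarge",
                    distribution={"torch_distributed": {"enabled": True}},
                )
                estimator.fit({"training": "s3://amzn-s3-demo-bucket/train"})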
""" if training_recipe is not None: if entry_point is not None: logger.warning("Argument entry_point will be ignored with training_recipe.") if hyperparameters is not None: logger.warning("Argument hyperparameters will be ignored with training recipe.") if distribution is not None: logger.warning("Argument distribution will be ignored with training_recipe.") args = self._setup_for_training_recipe( training_recipe, recipe_overrides, source_dir, kwargs ) entry_point = args["entry_point"] source_dir = args["source_dir"] hyperparameters = args["hyperparameters"] if image_uri is None: image_uri = args["default_image_uri"] distribution = args["distribution"] elif entry_point is None: raise ValueError( "Argument entry_point must be set when training_recipe is not provided" ) validate_version_or_image_args(framework_version, py_version, image_uri) if py_version == "py2": logger.warning( python_deprecation_warning(self._framework_name, defaults.LATEST_PY2_VERSION) ) self.framework_version = framework_version self.py_version = py_version if "enable_sagemaker_metrics" not in kwargs: # enable sagemaker metrics for PT v1.3 or greater: if self.framework_version and Version(self.framework_version) >= Version("1.3"): kwargs["enable_sagemaker_metrics"] = True super(PyTorch, self).__init__( entry_point, source_dir, hyperparameters, image_uri=image_uri, **kwargs ) if "entry_point" not in kwargs: kwargs["entry_point"] = entry_point if distribution is not None: # rewrite pytorchddp to smdistributed if "pytorchddp" in distribution: if "smdistributed" in distribution: raise ValueError( "Cannot use both pytorchddp and smdistributed " "distribution options together.", distribution, ) # convert pytorchddp distribution into smdistributed distribution distribution = distribution.copy() distribution["smdistributed"] = {"dataparallel": distribution["pytorchddp"]} del distribution["pytorchddp"] distribution = validate_distribution( distribution, self.instance_groups, self._framework_name, framework_version, py_version, image_uri, kwargs, ) self.distribution = distribution or {} if compiler_config is not None: if not isinstance(compiler_config, TrainingCompilerConfig): error_string = ( f"Expected instance of type {TrainingCompilerConfig}" f"for argument compiler_config. " f"Instead got {type(compiler_config)}" ) raise ValueError(error_string) if compiler_config: compiler_config.validate(self) elif distribution is not None and "pytorchxla" in distribution: raise ValueError( "Distributed training through PyTorch XLA is currently only supported " "when SageMaker Training Compiler is enabled. To learn more, " "see Enable SageMaker Training Compiler at " "https://docs.aws.amazon.com/sagemaker/latest/dg/training-compiler-enable.html." ) self.compiler_config = compiler_config if "profiler_config" in kwargs: profiler_config_deprecation_warning( kwargs["profiler_config"], image_uri, self._framework_name, framework_version ) def _pytorch_distribution_configuration(self, distribution): """Returns a dict of distribution config for PyTorch training Args: distribution (dict): A dictionary with information on how to run distributed training. 

        Returns:
            dict containing PyTorch DDP config
        """
        distribution_config = {}
        pytorch_ddp_enabled = False
        torch_distributed_enabled = False

        if "pytorchddp" in distribution:
            pytorch_ddp_enabled = distribution.get("pytorchddp").get("enabled", False)
        elif "torch_distributed" in distribution:
            torch_distributed_enabled = distribution.get("torch_distributed").get("enabled", False)

        if pytorch_ddp_enabled:
            distribution_config[self.LAUNCH_PYTORCH_DDP_ENV_NAME] = pytorch_ddp_enabled
            if self.instance_type is not None:
                distribution_config[self.INSTANCE_TYPE_ENV_NAME] = self.instance_type
        elif torch_distributed_enabled:
            if "smdistributed" in distribution:
                # Enable torch_distributed for smdistributed.
                distribution_config = self._distribution_configuration(distribution=distribution)
            distribution_config[self.LAUNCH_TORCH_DISTRIBUTED_ENV_NAME] = torch_distributed_enabled
            if self.instance_type is not None:
                distribution_config[self.INSTANCE_TYPE_ENV_NAME] = self.instance_type
        else:
            distribution_config = self._distribution_configuration(distribution=distribution)

        return distribution_config

    def hyperparameters(self):
        """Return hyperparameters used by your custom PyTorch code during model training."""
        hyperparameters = super(PyTorch, self).hyperparameters()
        additional_hyperparameters = self._pytorch_distribution_configuration(
            distribution=self.distribution
        )
        hyperparameters.update(
            EstimatorBase._json_encode_hyperparameters(additional_hyperparameters)
        )
        if self.compiler_config:
            training_compiler_hyperparameters = self.compiler_config._to_hyperparameter_dict()
            hyperparameters.update(
                EstimatorBase._json_encode_hyperparameters(training_compiler_hyperparameters)
            )

        return hyperparameters

    def create_model(
        self,
        model_server_workers=None,
        role=None,
        vpc_config_override=VPC_CONFIG_DEFAULT,
        entry_point=None,
        source_dir=None,
        dependencies=None,
        **kwargs,
    ):
        """Create a SageMaker ``PyTorchModel`` object that can be deployed to an ``Endpoint``.

        Args:
            model_server_workers (int): Optional. The number of worker processes used by the
                inference server. If None, server will use one worker per vCPU.
            role (str): The ``ExecutionRoleArn`` IAM Role ARN for the ``Model``, which is also
                used during transform jobs. If not specified, the role from the Estimator will
                be used.
            vpc_config_override (dict[str, list[str]]): Optional override for VpcConfig set on
                the model. Default: use subnets and security groups from this Estimator.

                * 'Subnets' (list[str]): List of subnet ids.
                * 'SecurityGroupIds' (list[str]): List of security group ids.

            entry_point (str): Path (absolute or relative) to the local Python source file which
                should be executed as the entry point to training. If ``source_dir`` is
                specified, then ``entry_point`` must point to a file located at the root of
                ``source_dir``. If not specified, the training entry point is used.
            source_dir (str): Path (absolute or relative) to a directory with any other serving
                source code dependencies aside from the entry point file. If not specified,
                the model source directory from training is used.
            dependencies (list[str]): A list of paths to directories (absolute or relative) with
                any additional libraries that will be exported to the container. If not
                specified, the dependencies from training are used. This is not supported with
                "local code" in Local Mode.
            **kwargs: Additional kwargs passed to the
                :class:`~sagemaker.pytorch.model.PyTorchModel` constructor.

        Returns:
            sagemaker.pytorch.model.PyTorchModel: A SageMaker ``PyTorchModel`` object.
                See :func:`~sagemaker.pytorch.model.PyTorchModel` for full details.
        """
        if "image_uri" not in kwargs:
            kwargs["image_uri"] = self.image_uri

        kwargs["name"] = self._get_or_create_name(kwargs.get("name"))

        return PyTorchModel(
            self.model_data,
            role or self.role,
            entry_point or self._model_entry_point(),
            framework_version=self.framework_version,
            py_version=self.py_version,
            source_dir=(source_dir or self._model_source_dir()),
            container_log_level=self.container_log_level,
            code_location=self.code_location,
            model_server_workers=model_server_workers,
            sagemaker_session=self.sagemaker_session,
            vpc_config=self.get_vpc_config(vpc_config_override),
            dependencies=(dependencies or self.dependencies),
            **kwargs,
        )

    @classmethod
    def _prepare_init_params_from_job_description(cls, job_details, model_channel_name=None):
        """Convert the job description to init params that can be handled by the class constructor.

        Args:
            job_details: the returned job details from a describe_training_job API call.
            model_channel_name (str): Name of the channel where pre-trained model data will be
                downloaded.

        Returns:
            dictionary: The transformed init_params
        """
        init_params = super(PyTorch, cls)._prepare_init_params_from_job_description(
            job_details, model_channel_name
        )
        image_uri = init_params.pop("image_uri")
        framework, py_version, tag, _ = framework_name_from_image(image_uri)

        if framework:
            framework = framework.split("-")[0]

        if tag is None:
            framework_version = None
        else:
            framework_version = framework_version_from_tag(tag)
        init_params["framework_version"] = framework_version
        init_params["py_version"] = py_version

        if not framework:
            # If we were unable to parse the framework name from the image, it is not one of
            # our officially supported images; in this case just add the image to the init
            # params.
            init_params["image_uri"] = image_uri
            return init_params

        if framework != cls._framework_name:
            raise ValueError(
                "Training job: {} didn't use image for requested framework".format(
                    job_details["TrainingJobName"]
                )
            )

        return init_params

    @classmethod
    def _setup_for_training_recipe(cls, training_recipe, recipe_overrides, source_dir, kwargs):
        """Performs training recipe specific setup and returns recipe specific args.

        Updates kwargs and returns a dictionary of args to use for estimator
        initialization and setup when using a training recipe. Updates the paths in
        the recipe for the SageMaker Jobs environment.

        Args:
            training_recipe (str): A recipe which is a local file path, a URL or a
                SageMaker training recipe.
            recipe_overrides (Dict): Dictionary specifying key values to override in the
                training_recipe.
            source_dir (str): Path (absolute or relative) to a directory where to copy the
                scripts for the training recipe. requirements.txt can also go here.
            kwargs (dict): Dictionary of args used for estimator initialization.

        Returns:
            dict containing arg values for estimator initialization and setup.
        """
        if kwargs.get("sagemaker_session") is not None:
            region_name = kwargs.get("sagemaker_session").boto_region_name
        else:
            region_name = Session().boto_region_name

        training_recipes_cfg_filename = os.path.join(
            os.path.dirname(__file__), "training_recipes.json"
        )
        with open(training_recipes_cfg_filename) as training_recipes_cfg_file:
            training_recipes_cfg = json.load(training_recipes_cfg_file)

        if recipe_overrides is None:
            recipe_overrides = dict()
        recipe_train_dir = tempfile.TemporaryDirectory(prefix="training_")
        recipe_launcher_dir = tempfile.TemporaryDirectory(prefix="launcher_")
        args = dict()
        if source_dir is None:
            args["source_dir"] = "."
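            # Stage the recipe's training scripts in the current working directory when
            # no source_dir is supplied.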
        else:
            if not os.path.exists(source_dir):
                raise ValueError(
                    "When using training_recipe, source_dir must be a local directory."
                )
            args["source_dir"] = source_dir

        recipe_name = os.path.splitext(os.path.basename(training_recipe))[0]
        temp_local_recipe = tempfile.NamedTemporaryFile(prefix=recipe_name, suffix=".yaml").name
        if training_recipe.endswith(".yaml"):
            if os.path.isfile(training_recipe):
                shutil.copy(training_recipe, temp_local_recipe)
            else:
                try:
                    urlretrieve(training_recipe, temp_local_recipe)
                except Exception as e:
                    raise ValueError(
                        f"Could not fetch the provided recipe {training_recipe}: "
                        f"exception {str(e)}"
                    )
        else:
            launcher_repo = os.environ.get(
                "TRAINING_LAUNCHER_GIT", None
            ) or training_recipes_cfg.get("launcher_repo")
            _run_clone_command(launcher_repo, recipe_launcher_dir.name)
            recipe = os.path.join(
                recipe_launcher_dir.name,
                "recipes_collection",
                "recipes",
                training_recipe + ".yaml",
            )
            if os.path.isfile(recipe):
                shutil.copy(recipe, temp_local_recipe)
            else:
                raise ValueError(f"Recipe {training_recipe} not found.")

        recipe = OmegaConf.load(temp_local_recipe)
        os.unlink(temp_local_recipe)
        recipe = OmegaConf.merge(recipe, recipe_overrides)

        if "instance_type" not in kwargs:
            raise ValueError("Must pass instance type to estimator when using training recipes.")
        instance_type = kwargs["instance_type"].split(".")[1]
        if instance_type.startswith(("p", "g")):
            device_type = "gpu"
        elif instance_type.startswith("trn"):
            device_type = "trainium"
        else:
            device_type = "cpu"

        if "trainer" not in recipe:
            raise ValueError("Supplied recipe does not contain required field trainer.")
        if "instance_count" in kwargs and "num_nodes" in recipe["trainer"]:
            logger.warning(
                "Using instance_count argument to estimator to set number "
                "of nodes. Ignoring trainer -> num_nodes in recipe."
            )
        if "instance_count" not in kwargs:
            if "num_nodes" not in recipe["trainer"]:
                raise ValueError(
                    "Must set either instance_count argument for estimator or "
                    "set trainer -> num_nodes in recipe."
                )
            kwargs["instance_count"] = recipe["trainer"]["num_nodes"]

        # [TODO] Add image uris to image_uri_config/_.json and use image_uris.retrieve
        # to retrieve the image uri below before we go GA.
        if device_type == "gpu":
            adapter_repo = os.environ.get(
                "TRAINING_ADAPTER_GIT", None
            ) or training_recipes_cfg.get("adapter_repo")
            _run_clone_command(adapter_repo, recipe_train_dir.name)
            script = _get_training_recipe_gpu_script(
                recipe_train_dir.name, recipe, args["source_dir"]
            )
            args["default_image_uri"] = _get_training_recipe_image_uri(
                training_recipes_cfg.get("gpu_image"), region_name
            )
            smp_options = {
                "enabled": True,
                "parameters": {
                    "placement_strategy": "cluster",
                },
            }
            args["distribution"] = {
                "smdistributed": {"modelparallel": smp_options},
                "torch_distributed": {"enabled": True},
            }
        elif device_type == "trainium":
            _run_clone_command(training_recipes_cfg.get("neuron_dist_repo"), recipe_train_dir.name)
            script = _get_training_recipe_trainium_script(
                recipe_train_dir.name, args["source_dir"]
            )
            args["default_image_uri"] = _get_training_recipe_image_uri(
                training_recipes_cfg.get("neuron_image"), region_name
            )
            args["distribution"] = {
                "torch_distributed": {"enabled": True},
            }
        else:
            raise ValueError(
                f"Devices of type {device_type} are not supported with training recipes."
            )
        args["entry_point"] = os.path.basename(script)

        recipe_train_dir.cleanup()
        recipe_launcher_dir.cleanup()

        # Warn when the recipe specifies a container image, since the training image is
        # controlled by the image_uri argument (or the recipe config default) instead.
        if "container" in recipe and recipe["container"]:
            logger.warning(
                "Ignoring container from training_recipe. Use image_uri arg for estimator."
            )

        _setup_omegaconf_resolvers()
        final_recipe = _try_resolve_recipe(recipe)
        if final_recipe is None:
            final_recipe = _try_resolve_recipe(recipe, "recipes")
        if final_recipe is None:
            final_recipe = _try_resolve_recipe(recipe, "training")
        if final_recipe is None:
            raise RuntimeError("Could not resolve provided recipe.")
        cls.training_recipe_file = tempfile.NamedTemporaryFile(
            dir=args["source_dir"],
            prefix=recipe_name + "_",
            suffix=".yaml",
        )
        OmegaConf.save(config=final_recipe, f=cls.training_recipe_file.name)
        args["hyperparameters"] = {
            "config-path": ".",
            "config-name": os.path.basename(cls.training_recipe_file.name),
        }

        return args
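

# Illustrative usage sketch (comment only, not executed as part of this module):
# launching a training job from a SageMaker HyperPod recipe. The role ARN, instance
# settings, and override values below are placeholders, and data locations are assumed
# to be configured by the recipe or recipe_overrides.
#
#   from sagemaker.pytorch import PyTorch
#
#   estimator = PyTorch(
#       training_recipe="training/llama/hf_llama3_70b_seq8k_gpu_p5x64_pretrain",
#       recipe_overrides={"trainer": {"num_nodes": 2}},
#       role="arn:aws:iam::111122223333:role/SageMakerRole",
#       instance_type="ml.p5.48xlarge",
#   )
#   estimator.fit()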