src/sagemaker/modules/train/sm_recipes/utils.py (226 lines of code) (raw):

# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"). You # may not use this file except in compliance with the License. A copy of # the License is located at # # http://aws.amazon.com/apache2.0/ # # or in the "license" file accompanying this file. This file is # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. """Utility functions for SageMaker training recipes.""" from __future__ import absolute_import import math import os import json import shutil import tempfile from urllib.request import urlretrieve from typing import Dict, Any, Optional, Tuple import omegaconf from omegaconf import OmegaConf, dictconfig from sagemaker.image_uris import retrieve from sagemaker.modules import logger from sagemaker.modules.utils import _run_clone_command_silent from sagemaker.modules.configs import Compute, SourceCode from sagemaker.modules.distributed import Torchrun, SMP def _try_resolve_recipe(recipe, key=None): """Try to resolve recipe and return resolved recipe.""" if key is not None: recipe = dictconfig.DictConfig({key: recipe}) try: OmegaConf.resolve(recipe) except omegaconf.errors.OmegaConfBaseException: return None if key is None: return recipe return recipe[key] def _determine_device_type(instance_type: str) -> str: """Determine device type (gpu, cpu, trainium) based on instance type.""" instance_family = instance_type.split(".")[1] if instance_family.startswith(("p", "g")): return "gpu" if instance_family.startswith("trn"): return "trainium" return "cpu" def _load_recipes_cfg() -> str: """Load training recipes configuration json.""" training_recipes_cfg_filename = os.path.join(os.path.dirname(__file__), "training_recipes.json") with open(training_recipes_cfg_filename) as training_recipes_cfg_file: training_recipes_cfg = json.load(training_recipes_cfg_file) return training_recipes_cfg def _load_base_recipe( training_recipe: str, recipe_overrides: Optional[Dict[str, Any]] = None, training_recipes_cfg: Optional[Dict[str, Any]] = None, ) -> Dict[str, Any]: """Load recipe and apply overrides.""" if recipe_overrides is None: recipe_overrides = dict() temp_local_recipe = tempfile.NamedTemporaryFile(prefix="recipe_original", suffix=".yaml").name if training_recipe.endswith(".yaml"): if os.path.isfile(training_recipe): shutil.copy(training_recipe, temp_local_recipe) else: try: urlretrieve(training_recipe, temp_local_recipe) except Exception as e: raise ValueError( f"Could not fetch the provided recipe {training_recipe}: exception {str(e)}" ) else: recipe_launcher_dir = tempfile.TemporaryDirectory(prefix="launcher_") launcher_repo = os.environ.get("TRAINING_LAUNCHER_GIT", None) or training_recipes_cfg.get( "launcher_repo" ) _run_clone_command_silent(launcher_repo, recipe_launcher_dir.name) recipe = os.path.join( recipe_launcher_dir.name, "recipes_collection", "recipes", training_recipe + ".yaml", ) if os.path.isfile(recipe): shutil.copy(recipe, temp_local_recipe) else: raise ValueError(f"Recipe {training_recipe} not found.") recipe = OmegaConf.load(temp_local_recipe) os.unlink(temp_local_recipe) recipe = OmegaConf.merge(recipe, recipe_overrides) return recipe def _register_custom_resolvers(): """Register custom resolvers for OmegaConf.""" if not OmegaConf.has_resolver("multiply"): OmegaConf.register_new_resolver("multiply", lambda x, y: x * y, replace=True) if not OmegaConf.has_resolver("divide_ceil"): OmegaConf.register_new_resolver( "divide_ceil", lambda x, y: int(math.ceil(x / y)), replace=True ) if not OmegaConf.has_resolver("divide_floor"): OmegaConf.register_new_resolver( "divide_floor", lambda x, y: int(math.floor(x / y)), replace=True ) if not OmegaConf.has_resolver("add"): OmegaConf.register_new_resolver("add", lambda *numbers: sum(numbers)) def _get_trainining_recipe_gpu_model_name_and_script(model_type: str): """Get the model base name and script for the training recipe.""" model_type_to_script = { "llama_v3": ("llama", "llama_pretrain.py"), "mistral": ("mistral", "mistral_pretrain.py"), "mixtral": ("mixtral", "mixtral_pretrain.py"), "deepseek": ("deepseek", "deepseek_pretrain.py"), } for key in model_type_to_script: if model_type.startswith(key): model_type = key break if model_type not in model_type_to_script: raise ValueError(f"Model type {model_type} not supported") return model_type_to_script[model_type][0], model_type_to_script[model_type][1] def _configure_gpu_args( training_recipes_cfg: Dict[str, Any], region_name: str, recipe: OmegaConf, recipe_train_dir: tempfile.TemporaryDirectory, ) -> Dict[str, Any]: """Configure arguments specific to GPU.""" source_code = SourceCode() args = dict() adapter_repo = os.environ.get("TRAINING_ADAPTER_GIT", None) or training_recipes_cfg.get( "adapter_repo" ) _run_clone_command_silent(adapter_repo, recipe_train_dir.name) if "model" not in recipe: raise ValueError("Supplied recipe does not contain required field model.") if "model_type" not in recipe["model"]: raise ValueError("Supplied recipe does not contain required field model_type.") model_type = recipe["model"]["model_type"] model_base_name, script = _get_trainining_recipe_gpu_model_name_and_script(model_type) source_code.source_dir = os.path.join(recipe_train_dir.name, "examples", model_base_name) source_code.entry_script = script gpu_image_cfg = training_recipes_cfg.get("gpu_image") if isinstance(gpu_image_cfg, str): training_image = gpu_image_cfg else: training_image = retrieve( gpu_image_cfg.get("framework"), region=region_name, version=gpu_image_cfg.get("version"), image_scope="training", **gpu_image_cfg.get("additional_args"), ) # Setting dummy parameters for now torch_distributed = Torchrun(smp=SMP(random_seed="123456")) args.update( { "source_code": source_code, "training_image": training_image, "distributed": torch_distributed, } ) return args def _configure_trainium_args( training_recipes_cfg: Dict[str, Any], region_name: str, recipe_train_dir: tempfile.TemporaryDirectory, ) -> Dict[str, Any]: """Configure arguments specific to Trainium.""" source_code = SourceCode() args = dict() _run_clone_command_silent(training_recipes_cfg.get("neuron_dist_repo"), recipe_train_dir.name) source_code.source_dir = os.path.join(recipe_train_dir.name, "examples") source_code.entry_script = "training_orchestrator.py" neuron_image_cfg = training_recipes_cfg.get("neuron_image") if isinstance(neuron_image_cfg, str): training_image = neuron_image_cfg else: training_image = retrieve( neuron_image_cfg.get("framework"), region=region_name, version=neuron_image_cfg.get("version"), image_scope="training", **neuron_image_cfg.get("additional_args"), ) args.update( { "source_code": source_code, "training_image": training_image, "distributed": Torchrun(), } ) return args def _get_args_from_recipe( training_recipe: str, compute: Compute, region_name: str, recipe_overrides: Optional[Dict[str, Any]], requirements: Optional[str], ) -> Tuple[Dict[str, Any], tempfile.TemporaryDirectory]: """Get arguments for ModelTrainer from a training recipe. Returns a dictionary of arguments to be used with ModelTrainer like: ```python { "source_code": SourceCode, "training_image": str, "distributed": DistributedConfig, "compute": Compute, "hyperparameters": Dict[str, Any], } ``` Args: training_recipe (str): Name of the training recipe or path to the recipe file. compute (Compute): Compute configuration for training. region_name (str): Name of the AWS region. recipe_overrides (Optional[Dict[str, Any]]): Overrides for the training recipe. requirements (Optional[str]): Path to the requirements file. """ if compute.instance_type is None: raise ValueError("Must set `instance_type` in compute when using training recipes.") training_recipes_cfg = _load_recipes_cfg() recipe = _load_base_recipe(training_recipe, recipe_overrides, training_recipes_cfg) if "trainer" not in recipe: raise ValueError("Supplied recipe does not contain required field trainer.") # Set instance_count if compute.instance_count and "num_nodes" in recipe["trainer"]: logger.warning( f"Using Compute to set instance_count:\n{compute}." "\nIgnoring trainer -> num_nodes in recipe." ) if compute.instance_count is None: if "num_nodes" not in recipe["trainer"]: raise ValueError( "Must provide Compute with instance_count or" " set trainer -> num_nodes in recipe." ) compute.instance_count = recipe["trainer"]["num_nodes"] if requirements and not os.path.isfile(requirements): raise ValueError(f"Recipe requirements file {requirements} not found.") # Get Training Image, SourceCode, and distributed args device_type = _determine_device_type(compute.instance_type) recipe_train_dir = tempfile.TemporaryDirectory(prefix="training_") if device_type == "gpu": args = _configure_gpu_args(training_recipes_cfg, region_name, recipe, recipe_train_dir) elif device_type == "trainium": args = _configure_trainium_args(training_recipes_cfg, region_name, recipe_train_dir) else: raise ValueError(f"Devices of type {device_type} are not supported with training recipes.") _register_custom_resolvers() # Resolve Final Recipe final_recipe = _try_resolve_recipe(recipe) if final_recipe is None: final_recipe = _try_resolve_recipe(recipe, "recipes") if final_recipe is None: final_recipe = _try_resolve_recipe(recipe, "training") if final_recipe is None: raise RuntimeError("Could not resolve provided recipe.") # Save Final Recipe to source_dir OmegaConf.save( config=final_recipe, f=os.path.join(args["source_code"].source_dir, "recipe.yaml") ) # If recipe_requirements is provided, copy it to source_dir if requirements: shutil.copy(requirements, args["source_code"].source_dir) args["source_code"].requirements = os.path.basename(requirements) # Update args with compute and hyperparameters args.update( { "compute": compute, "hyperparameters": {"config-path": ".", "config-name": "recipe.yaml"}, } ) return args, recipe_train_dir