in src/sagemaker/pytorch/estimator.py [0:0]
def _setup_for_training_recipe(cls, training_recipe, recipe_overrides, source_dir, kwargs):
    """Performs training recipe specific setup and returns recipe specific args.

    Updates kwargs and returns a dictionary of args to use for estimator
    initialization and setup when using a training recipe. Updates the paths in
    the recipe for Sagemaker Jobs environment.

    Args:
        training_recipe (str): A recipe which is a local file path, a url or a
            sagemaker training recipe.
        recipe_overrides (Dict): Dictionary specifying key values to override in
            the training recipe.
        source_dir (str): Path (absolute, or relative) to a directory where to copy
            the scripts for training recipe. requirements.txt can also go here.
        kwargs (dict): Dictionary of args used for estimator initialization.

    Returns:
        dict containing arg values for estimator initialization and setup.

    Raises:
        ValueError: If the recipe cannot be fetched or found, required arguments
            are missing, or the device type is unsupported.
        RuntimeError: If the merged recipe cannot be resolved.
    """
    # Region comes from the caller's session if given, otherwise a default one.
    if kwargs.get("sagemaker_session") is not None:
        region_name = kwargs.get("sagemaker_session").boto_region_name
    else:
        region_name = Session().boto_region_name

    # Static config shipped alongside this module: repo URLs and image names
    # used for the recipe launcher / adapter clones below.
    training_recipes_cfg_filename = os.path.join(
        os.path.dirname(__file__), "training_recipes.json"
    )
    with open(training_recipes_cfg_filename) as training_recipes_cfg_file:
        training_recipes_cfg = json.load(training_recipes_cfg_file)

    if recipe_overrides is None:
        recipe_overrides = dict()
    recipe_train_dir = tempfile.TemporaryDirectory(prefix="training_")
    recipe_launcher_dir = tempfile.TemporaryDirectory(prefix="launcher_")

    args = dict()
    if source_dir is None:
        args["source_dir"] = "."
    else:
        if not os.path.exists(source_dir):
            raise ValueError(
                "When using training_recipe, source_dir must be a local directory."
            )
        args["source_dir"] = source_dir

    recipe_name = os.path.splitext(os.path.basename(training_recipe))[0]
    # NOTE: only the generated name is used; the NamedTemporaryFile object is
    # discarded immediately so the path is free for shutil.copy/urlretrieve.
    temp_local_recipe = tempfile.NamedTemporaryFile(prefix=recipe_name, suffix=".yaml").name
    if training_recipe.endswith(".yaml"):
        # The recipe is an explicit yaml file: local path or fetchable URL.
        if os.path.isfile(training_recipe):
            shutil.copy(training_recipe, temp_local_recipe)
        else:
            try:
                urlretrieve(training_recipe, temp_local_recipe)
            except Exception as e:
                raise ValueError(
                    f"Could not fetch the provided recipe {training_recipe}: exception {str(e)}"
                ) from e
    else:
        # Otherwise treat it as a named recipe from the launcher repository.
        launcher_repo = os.environ.get(
            "TRAINING_LAUNCHER_GIT", None
        ) or training_recipes_cfg.get("launcher_repo")
        _run_clone_command(launcher_repo, recipe_launcher_dir.name)
        recipe = os.path.join(
            recipe_launcher_dir.name,
            "recipes_collection",
            "recipes",
            training_recipe + ".yaml",
        )
        if os.path.isfile(recipe):
            shutil.copy(recipe, temp_local_recipe)
        else:
            raise ValueError(f"Recipe {training_recipe} not found.")

    recipe = OmegaConf.load(temp_local_recipe)
    os.unlink(temp_local_recipe)
    recipe = OmegaConf.merge(recipe, recipe_overrides)

    if "instance_type" not in kwargs:
        raise ValueError("Must pass instance type to estimator when using training recipes.")
    # Extract the instance family, e.g. "ml.p4d.24xlarge" -> "p4d".
    instance_type_parts = kwargs["instance_type"].split(".")
    if len(instance_type_parts) < 2:
        raise ValueError(
            "Must pass a valid instance type to estimator when using training recipes."
        )
    instance_type = instance_type_parts[1]
    if instance_type.startswith(("p", "g")):
        device_type = "gpu"
    elif instance_type.startswith("trn"):
        device_type = "trainium"
    else:
        device_type = "cpu"

    if "trainer" not in recipe:
        raise ValueError("Supplied recipe does not contain required field trainer.")

    # instance_count passed to the estimator takes precedence over the recipe.
    if "instance_count" in kwargs and "num_nodes" in recipe["trainer"]:
        logger.warning(
            "Using instance_count argument to estimator to set number "
            "of nodes. Ignoring trainer -> num_nodes in recipe."
        )
    if "instance_count" not in kwargs:
        if "num_nodes" not in recipe["trainer"]:
            raise ValueError(
                "Must set either instance_count argument for estimator or "
                "set trainer -> num_nodes in recipe."
            )
        kwargs["instance_count"] = recipe["trainer"]["num_nodes"]

    # [TODO] Add image uris to image_uri_config/_.json and use image_uris.retrieve
    # to retrieve the image uri below before we go GA.
    if device_type == "gpu":
        adapter_repo = os.environ.get("TRAINING_ADAPTER_GIT", None) or training_recipes_cfg.get(
            "adapter_repo"
        )
        _run_clone_command(adapter_repo, recipe_train_dir.name)
        script = _get_training_recipe_gpu_script(
            recipe_train_dir.name, recipe, args["source_dir"]
        )
        args["default_image_uri"] = _get_training_recipe_image_uri(
            training_recipes_cfg.get("gpu_image"), region_name
        )
        smp_options = {
            "enabled": True,
            "parameters": {
                "placement_strategy": "cluster",
            },
        }
        args["distribution"] = {
            "smdistributed": {"modelparallel": smp_options},
            "torch_distributed": {"enabled": True},
        }
    elif device_type == "trainium":
        _run_clone_command(training_recipes_cfg.get("neuron_dist_repo"), recipe_train_dir.name)
        script = _get_training_recipe_trainium_script(recipe_train_dir.name, args["source_dir"])
        args["default_image_uri"] = _get_training_recipe_image_uri(
            training_recipes_cfg.get("neuron_image"), region_name
        )
        args["distribution"] = {
            "torch_distributed": {"enabled": True},
        }
    else:
        raise ValueError(
            f"Devices of type {device_type} are not supported with training recipes."
        )
    args["entry_point"] = os.path.basename(script)

    recipe_train_dir.cleanup()
    recipe_launcher_dir.cleanup()

    # Fix: warn when a container IS specified in the recipe (it is ignored in
    # favor of the image_uri estimator arg); the original condition was
    # inverted and only warned when the field was empty.
    if "container" in recipe and recipe["container"]:
        logger.warning(
            "Ignoring container from training_recipe. Use image_uri arg for estimator."
        )

    _setup_omegaconf_resolvers()
    # Try resolving the recipe at the top level first, then under the nesting
    # keys the launcher repo uses.
    final_recipe = _try_resolve_recipe(recipe)
    if final_recipe is None:
        final_recipe = _try_resolve_recipe(recipe, "recipes")
    if final_recipe is None:
        final_recipe = _try_resolve_recipe(recipe, "training")
    if final_recipe is None:
        raise RuntimeError("Could not resolve provided recipe.")

    # Stored on the class so the NamedTemporaryFile stays open (and on disk in
    # source_dir, where it ships with the training job) past this call.
    cls.training_recipe_file = tempfile.NamedTemporaryFile(
        dir=args["source_dir"],
        prefix=recipe_name + "_",
        suffix=".yaml",
    )
    OmegaConf.save(config=final_recipe, f=cls.training_recipe_file.name)
    args["hyperparameters"] = {
        "config-path": ".",
        "config-name": os.path.basename(cls.training_recipe_file.name),
    }
    return args