in main.py [0:0]
def get_training_stage(cfg):
    """
    Pick the training stage matching the instance family and the job kind.

    The instance type obtained from ``get_instance_type(cfg)`` is bucketed by
    its prefix: "p" or "g" counts as GPU, "trn" counts as Trainium, and
    anything else is treated as CPU. A job is considered custom training when
    ``cfg.get("training_cfg")`` is not None; otherwise it is a recipe job.

    Returns:
        SMTrainingGPURecipe / SMTrainingTrainiumRecipe for recipe jobs, or
        SMCustomTrainingGPU / SMCustomTrainingTrainium / SMCustomTrainingCPU
        for custom jobs.

    Raises:
        ValueError: a recipe job was requested on an instance that is
            neither GPU nor Trainium.
    """
    instance_type = get_instance_type(cfg)
    runs_recipe = cfg.get("training_cfg") is None

    # p and g instance families are GPU-backed; "trn" is only probed when the
    # GPU prefixes did not match.
    on_gpu = instance_type.startswith(("p", "g"))
    on_trainium = not on_gpu and instance_type.startswith("trn")

    if on_gpu:
        return SMTrainingGPURecipe if runs_recipe else SMCustomTrainingGPU
    if on_trainium:
        return SMTrainingTrainiumRecipe if runs_recipe else SMCustomTrainingTrainium

    # Remaining case is CPU: only custom training has a stage for it.
    if runs_recipe:
        raise ValueError("Recipe only can be run on GPU or Trainium instances")
    return SMCustomTrainingCPU