def get_training_stage()

in main.py [0:0]


def get_training_stage(cfg):
    """
    Pick the training stage matching the configured instance and training mode.

    The instance type returned by ``get_instance_type(cfg)`` is classified by
    its prefix: names starting with "p" or "g" count as GPU, names starting
    with "trn" count as Trainium, and anything else counts as CPU. Training is
    considered custom when ``cfg.get("training_cfg")`` is not None.

    Returns the recipe stage (GPU or Trainium) for non-custom training, or the
    custom stage (GPU, Trainium or CPU) for custom training.

    Raises ValueError when a non-custom (recipe) run targets an instance that
    is neither GPU nor Trainium.
    """
    instance = get_instance_type(cfg)
    custom_training = cfg.get("training_cfg") is not None

    # "p*" and "g*" instance families are GPU; "trn*" is only checked when
    # the instance is not already classified as GPU.
    on_gpu = instance.startswith(("p", "g"))
    on_trainium = not on_gpu and instance.startswith("trn")

    if custom_training:
        # Custom training supports every device kind, with CPU as fallback.
        if on_gpu:
            return SMCustomTrainingGPU
        return SMCustomTrainingTrainium if on_trainium else SMCustomTrainingCPU

    # Recipe-based training has no CPU variant.
    if on_gpu:
        return SMTrainingGPURecipe
    if on_trainium:
        return SMTrainingTrainiumRecipe
    raise ValueError("Recipe only can be run on GPU or Trainium instances")