def get_training_func()

in Project-BasicAlgorithm/train.py [0:0]


def get_training_func(algorithm: str):
    """Return the training function for the named algorithm.

    Each backend is imported lazily inside its own branch, so only the
    selected algorithm's module (and its third-party dependencies) is loaded.

    Args:
        algorithm: One of ``"svm"``, ``"lightgbm"``, ``"xgboost"`` or ``"lr"``.

    Returns:
        The training callable exported by the matching ``core.training``
        module (``train_svc``, ``train_lightgbm``, ``train_xgboost`` or
        ``train_lr``).

    Raises:
        ValueError: If ``algorithm`` is not one of the supported names.
    """
    if algorithm == "svm":
        from core.training.svm import train_svc as training_func

    elif algorithm == "lightgbm":
        from core.training.lightgbm import train_lightgbm as training_func

    elif algorithm == "xgboost":
        from core.training.xgboost import train_xgboost as training_func

    elif algorithm == "lr":
        from core.training.lr import train_lr as training_func

    else:
        # Raise explicitly: the previous `assert "<non-empty string>"` was
        # always truthy, so it never fired and the function then failed with
        # an UnboundLocalError on `training_func`. An `assert` would also be
        # stripped under `python -O`.
        raise ValueError(
            f"{algorithm} not supported; "
            "expected one of: svm, lightgbm, xgboost, lr"
        )

    return training_func