def get_loaded_booster()

in src/sagemaker_xgboost_container/algorithm_mode/serve_utils.py [0:0]


def get_loaded_booster(model_dir, ensemble=False):
    """Load the model artifact(s) located under ``model_dir``.

    Each model file is first tried as a pickle. If unpickling fails, the file
    is loaded as a native XGBoost model via ``xgb.Booster().load_model``.
    Every loaded model gets ``set_param("nthread", 1)``.

    :param model_dir: directory handed to ``_get_full_model_paths`` to locate
        the model files.
    :param ensemble: when True, load every model path found; otherwise load
        only the first one.
    :return: a ``(models, model_formats)`` pair of lists when ``ensemble`` is
        True and more than one model was loaded; otherwise a single
        ``(model, model_format)`` pair for the first loaded model.
    :raises RuntimeError: if a model file can be loaded neither by pickle nor
        by ``xgb.Booster.load_model``; the message includes both errors.
    """
    full_model_paths = list(_get_full_model_paths(model_dir))
    # Outside ensemble mode only the first model path is used.
    full_model_paths = full_model_paths if ensemble else full_model_paths[0:1]

    models = []
    model_formats = []
    for full_model_path in full_model_paths:
        logging.info(f"Loading the model from {full_model_path}")
        try:
            # NOTE(review): unpickling can execute arbitrary code, so model
            # artifacts must come from a trusted source.
            # Use a context manager so the handle is closed even when
            # unpickling fails and we fall back to the XGBoost loader.
            with open(full_model_path, "rb") as model_file:
                booster = pkl.load(model_file)
            model_format = PKL_FORMAT
        except Exception as exp_pkl:
            try:
                booster = xgb.Booster()
                booster.load_model(full_model_path)
                model_format = XGB_FORMAT
            except Exception as exp_xgb:
                raise RuntimeError(
                    f"Model {full_model_path} cannot be loaded:"
                    f"\nPickle load error={str(exp_pkl)}"
                    f"\nXGB load model error={str(exp_xgb)}"
                ) from exp_xgb
        booster.set_param("nthread", 1)
        models.append(booster)
        model_formats.append(model_format)

    # If no model paths were found, `models` is empty and the indexing below
    # raises IndexError.
    return (models, model_formats) if ensemble and len(models) > 1 else (models[0], model_formats[0])