in src/sagemaker_xgboost_container/algorithm_mode/serve_utils.py [0:0]
def get_loaded_booster(model_dir, ensemble=False):
    """Load one or more models from ``model_dir``.

    Each model file is first tried with ``pkl.load``. If that fails for any
    reason (including the file failing to open), it is loaded with
    ``xgb.Booster().load_model``. Every loaded object has ``nthread`` set to 1.

    :param model_dir: directory handed to ``_get_full_model_paths`` to
        enumerate the model files.
    :param ensemble: if True, load every model path found; otherwise load only
        the first one.
    :return: ``(models, model_formats)`` lists when ``ensemble`` is True and
        more than one model was loaded; otherwise the single
        ``(booster, model_format)`` pair of the first model.
    :raises RuntimeError: if a model file can be loaded neither by pickle nor
        by ``xgb.Booster.load_model``.
    :raises IndexError: if no model paths were found.
    """
    full_model_paths = list(_get_full_model_paths(model_dir))
    if not ensemble:
        # Outside ensemble mode only the first model is used.
        full_model_paths = full_model_paths[0:1]

    models = []
    model_formats = []
    for full_model_path in full_model_paths:
        logging.info("Loading the model from %s", full_model_path)
        try:
            # NOTE(review): unpickling runs arbitrary code; model files must
            # come from a trusted source.
            # The context manager closes the handle even if unpickling fails.
            with open(full_model_path, "rb") as model_file:
                booster = pkl.load(model_file)
            model_format = PKL_FORMAT
        except Exception as exp_pkl:
            # Fall back to XGBoost's native model loader.
            try:
                booster = xgb.Booster()
                booster.load_model(full_model_path)
                model_format = XGB_FORMAT
            except Exception as exp_xgb:
                raise RuntimeError(
                    f"Model {full_model_path} cannot be loaded:"
                    f"\nPickle load error={str(exp_pkl)}"
                    f"\nXGB load model error={str(exp_xgb)}"
                ) from exp_xgb
        booster.set_param("nthread", 1)
        models.append(booster)
        model_formats.append(model_format)

    if not models:
        # Same exception type as the former bare ``models[0]`` access,
        # but with a message that explains the cause.
        raise IndexError(f"No model files found in {model_dir}")

    if ensemble and len(models) > 1:
        return models, model_formats
    return models[0], model_formats[0]