in 6-Pipelines/config/xgboost_customer_churn.py [0:0]
def model_fn(model_dir):
    """Load a model. For XGBoost Framework, a default function to load a model is not provided.
    Users should provide customized model_fn() in script.

    The first entry returned by ``os.listdir(model_dir)`` that is a file is loaded
    (listdir order is arbitrary, so ``model_dir`` is expected to hold a single model file).
    Loading is attempted first with :mod:`pickle`; if that fails, the file is loaded as a
    native XGBoost model via ``xgboost.Booster().load_model``. The loaded booster then has
    ``nthread`` set to 1 via ``set_param``.

    Args:
        model_dir: a directory where model is saved.

    Returns:
        A tuple ``(booster, model_format)``:
            booster: A XGBoost model.
            model_format: XGBoost model format type, ``'pkl_format'`` or ``'xgb_format'``.

    Raises:
        ModelLoadInferenceError: if ``model_dir`` contains no file, or the file can be
            loaded neither with pickle nor as a native XGBoost model.
    """
    model_file = next(
        (name for name in os.listdir(model_dir) if os.path.isfile(os.path.join(model_dir, name))),
        None,
    )
    if model_file is None:
        # Without a default, next() would leak a bare StopIteration to the caller.
        raise ModelLoadInferenceError("Unable to load model: no model file found in {}".format(model_dir))
    model_path = os.path.join(model_dir, model_file)

    try:
        # NOTE(review): pickle.load can execute arbitrary code embedded in the file;
        # only load model artifacts that come from a trusted source.
        with open(model_path, 'rb') as model_stream:
            booster = pickle.load(model_stream)
        model_format = 'pkl_format'
    except Exception as exp_pkl:
        # Not loadable as a pickle: fall back to XGBoost's native model format.
        try:
            booster = xgboost.Booster()
            booster.load_model(model_path)
            model_format = 'xgb_format'
        except Exception as exp_xgb:
            raise ModelLoadInferenceError(
                "Unable to load model: {} {}".format(str(exp_pkl), str(exp_xgb))
            ) from exp_xgb
    booster.set_param('nthread', 1)
    return booster, model_format