community-content/vertex_cpr_samples/xgboost/predictor_xgbranker.py (34 lines of code) (raw):

import os import numpy as np import pandas as pd import pickle import xgboost as xgb from google.cloud.aiplatform.constants import prediction from google.cloud.aiplatform.utils import prediction_utils from google.cloud.aiplatform.prediction.predictor import Predictor class XGBRankerPredictor(Predictor): def __init__(self): return def load(self, artifacts_uri: str) -> None: prediction_utils.download_model_artifacts(artifacts_uri) if os.path.exists(prediction.MODEL_FILENAME_PKL): booster = pickle.load(open(prediction.MODEL_FILENAME_PKL, "rb")) self._booster = booster else: N = 500 dates = pd.date_range(start='2023-01-01', end='2023-01-12', periods=N) X = pd.DataFrame(np.random.randn(N, 5), columns=list('ABCDE'), index=dates) y = pd.Series(np.random.randint(0, 10, size=N), index=dates, name='label') group = X.groupby(dates + pd.offsets.MonthEnd(0)).size() sample_weight = pd.Series(np.arange(len(group)), index=group.index) model = xgb.XGBRanker(objective='rank:pairwise', max_depth=3, learning_rate=0.1, booster='gbtree', tree_method='hist', n_jobs=4, n_estimators=50, enable_categorical=False, random_state=42) model.fit(X=X, y=y, group=group, sample_weight=sample_weight, verbose=True) booster = model.get_booster() self._booster = booster def preprocess(self, prediction_input: dict) -> xgb.DMatrix: instances = prediction_input["instances"] return xgb.DMatrix(instances) def predict(self, instances: xgb.DMatrix) -> np.ndarray: return self._booster.predict(instances, output_margin=False, ntree_limit=0) def postprocess(self, prediction_results: np.ndarray) -> dict: return {"predictions": prediction_results.tolist()}