def process()

in tfx_addons/xgboost_evaluator/xgboost_predict_extractor.py [0:0]


  def process(self, elem: tfma.Extracts) -> Iterable[tfma.Extracts]:
    """Uses loaded models to make predictions on batches of data.

    Args:
      elem: An extract containing batched features. `elem[tfma.FEATURES_KEY]`
        is iterated as a sequence of per-sample feature dicts. Each value
        looked up for a feature key or the label key is passed to
        `np.concatenate`, so it is presumably a 1-D array (and, for the
        DataFrame columns to line up, a single-element one) -- TODO confirm
        against the upstream batching extractor.

    Yields:
      A shallow copy of the original extracts with labels and predictions
      added. With a single model, that model's raw prediction array is stored
      under `tfma.PREDICTIONS_KEY`. With multiple models, a list of dicts keyed
      on model names is stored instead, each dict holding the predictions for
      a single sample. Any value already present under `tfma.PREDICTIONS_KEY`
      in `elem` is replaced, never mutated. If no models are loaded, no
      predictions are added.
    """
    # Build feature and label vectors because xgboost cannot read tf.Examples.
    result = copy.copy(elem)
    feature_rows = []
    labels = []
    for features_dict in result[tfma.FEATURES_KEY]:
      # One row per sample, with columns ordered as self._feature_keys.
      feature_rows.append(
          np.concatenate([features_dict[key] for key in self._feature_keys]))
      labels.append(features_dict[self._label_key])
    result[tfma.LABELS_KEY] = np.concatenate(labels)
    dmatrix = xgb.DMatrix(
        pd.DataFrame(feature_rows, columns=self._feature_keys))

    # Generate predictions for each model, keeping the models' iteration order.
    predictions_by_model = {
        model_name: loaded_model.predict(dmatrix)
        for model_name, loaded_model in self._loaded_models.items()
    }
    if len(predictions_by_model) == 1:
      result[tfma.PREDICTIONS_KEY] = next(iter(predictions_by_model.values()))
    elif predictions_by_model:
      # Build a fresh list of per-sample dicts. `result` is only a shallow copy
      # of `elem`, so writing into a list already stored under this key would
      # mutate the input extract (and would fail if that value is not a list
      # of dicts). All models predict on the same DMatrix, so every prediction
      # array has the same length.
      result[tfma.PREDICTIONS_KEY] = [
          dict(zip(predictions_by_model, sample_preds))
          for sample_preds in zip(*predictions_by_model.values())
      ]
    yield result