def _extract_model_metadata()

in inference/xgboost_predictor/predictor.py [0:0]


  def _extract_model_metadata(self):
    """Extracts info from model metadata and fills member variables.

    Reads `self._model_metadata` and populates `self._model_type`,
    `self._label_col`, `self._feature_names`, `self._class_names` (classifier
    model types only) and `self._feature_name_to_index_map`. For every feature
    whose encode_type relies on a per-feature-index lookup table (vocabularies,
    array dimensions/lengths), checks that the table has an entry for that
    feature's index.

    Raises:
      ValueError: An error occurred when:
        1. Invalid model type.
        2. Label not found.
        3. Features or feature_names not found.
        4. Class names not found for {boosted_tree|random_forest}_classifier.
        5. A name listed in feature_names has no entry in features.
        6. Feature index mismatch.
        7. Invalid encode type for categorical features.
    """
    metadata = self._model_metadata
    valid_model_types = (
        'boosted_tree_regressor', 'boosted_tree_classifier',
        'random_forest_regressor', 'random_forest_classifier',
        'boosted_tree_cox', 'boosted_tree_aft',
    )
    if ('model_type' not in metadata or
        metadata['model_type'] not in valid_model_types):
      raise ValueError('Invalid model_type in model_metadata')
    self._model_type = metadata['model_type']

    if 'label_col' not in metadata:
      raise ValueError('label_col not found in model_metadata')
    self._label_col = metadata['label_col']

    # Check presence before indexing so a missing key surfaces as the
    # documented ValueError rather than a bare KeyError.
    if 'features' not in metadata or not metadata['features']:
      raise ValueError('No feature found in model_metadata')
    features = metadata['features']
    if 'feature_names' not in metadata:
      raise ValueError('feature_names not found in model_metadata')
    self._feature_names = metadata['feature_names']

    if self._model_type in ('boosted_tree_classifier',
                            'random_forest_classifier'):
      if 'class_names' not in metadata or not metadata['class_names']:
        raise ValueError('No class_names found in model_metadata')
      self._class_names = metadata['class_names']

    # encode_type -> (name of the member table keyed by feature index,
    #                 whether the check only applies when that table is
    #                 non-empty). 'ohe' and 'mhe' are aliases of
    # 'categorical_label' and 'array_one_hot' respectively.
    lookup_tables = {
        'categorical_one_hot': ('_categorical_one_hot_vocab', False),
        'categorical_target': ('_categorical_target_vocab', False),
        'categorical_label': ('_categorical_label_vocab', False),
        'ohe': ('_categorical_label_vocab', False),
        'array_one_hot': ('_array_one_hot_vocab', False),
        'mhe': ('_array_one_hot_vocab', False),
        'array_target': ('_array_target_vocab', False),
        'array_struct': ('_array_struct_dimension_dict', True),
        'array_numerical': ('_array_numerical_length_dict', True),
    }

    for feature_index, feature_name in enumerate(self._feature_names):
      self._feature_name_to_index_map[feature_name] = feature_index
      if feature_name not in features:
        raise ValueError(
            'Feature %s in feature_names not found in features' % feature_name)
      encode_type = features[feature_name].get('encode_type')
      # A missing/empty encode_type is treated the same as numerical_identity:
      # no lookup table is needed.
      if not encode_type or encode_type == 'numerical_identity':
        continue
      # Guard the dict lookup so a non-string (possibly unhashable) value is
      # reported as an invalid encode_type instead of raising TypeError.
      spec = (lookup_tables.get(encode_type)
              if isinstance(encode_type, str) else None)
      if spec is None:
        raise ValueError('Invalid encode_type %s for feature %s' %
                         (encode_type, feature_name))
      table_name, only_if_populated = spec
      table = getattr(self, table_name)
      if only_if_populated and not table:
        continue
      if feature_index not in table:
        raise ValueError('feature_index %d missing in %s' %
                         (feature_index, table_name))