in inference/xgboost_predictor/predictor.py [0:0]
def _extract_model_metadata(self):
    """Extracts info from model metadata and fills member variables.

    Reads ``self._model_metadata`` and sets ``self._model_type``,
    ``self._label_col``, ``self._feature_names`` and (for classifier model
    types only) ``self._class_names``. It also records each feature's position
    in ``self._feature_name_to_index_map``. For every feature whose
    ``encode_type`` is set and is not ``numerical_identity``, it checks that
    the matching lookup dict on ``self`` has an entry for the feature's index.
    This method only reads those lookup dicts; presumably they are populated
    before it is called (verify against the caller).

    Raises:
      ValueError: An error occurred when:
        1. Invalid model type.
        2. Label not found.
        3. Features not found (``features`` missing or empty, or
           ``feature_names`` missing).
        4. Class names not found for
           {boosted_tree|random_forest}_classifier.
        5. A name in ``feature_names`` has no entry in ``features``.
        6. Feature index mismatch (index missing from its lookup dict).
        7. Invalid encode type for a feature.
    """
    valid_model_types = (
        'boosted_tree_regressor', 'boosted_tree_classifier',
        'random_forest_regressor', 'random_forest_classifier',
        'boosted_tree_cox', 'boosted_tree_aft',
    )
    metadata = self._model_metadata

    if metadata.get('model_type') not in valid_model_types:
        raise ValueError('Invalid model_type in model_metadata')
    self._model_type = metadata['model_type']

    if 'label_col' not in metadata:
        raise ValueError('label_col not found in model_metadata')
    self._label_col = metadata['label_col']

    # Use .get / explicit membership checks so missing keys surface as the
    # documented ValueError instead of a bare KeyError.
    features = metadata.get('features')
    if not features:
        raise ValueError('No feature found in model_metadata')
    if 'feature_names' not in metadata:
        raise ValueError('feature_names not found in model_metadata')
    self._feature_names = metadata['feature_names']

    if self._model_type in ('boosted_tree_classifier',
                            'random_forest_classifier'):
        if not metadata.get('class_names'):
            raise ValueError('No class_names found in model_metadata')
        self._class_names = metadata['class_names']

    # encode_type -> name of the per-feature lookup dict that must contain
    # the feature index. 'ohe' and 'mhe' are aliases sharing a dict with
    # 'categorical_label' and 'array_one_hot' respectively.
    required_lookup_attrs = {
        'categorical_one_hot': '_categorical_one_hot_vocab',
        'categorical_target': '_categorical_target_vocab',
        'categorical_label': '_categorical_label_vocab',
        'ohe': '_categorical_label_vocab',
        'array_one_hot': '_array_one_hot_vocab',
        'mhe': '_array_one_hot_vocab',
        'array_target': '_array_target_vocab',
    }
    # encode_type -> lookup dict that is only enforced when it is non-empty.
    optional_lookup_attrs = {
        'array_struct': '_array_struct_dimension_dict',
        'array_numerical': '_array_numerical_length_dict',
    }

    for feature_index, feature_name in enumerate(self._feature_names):
        self._feature_name_to_index_map[feature_name] = feature_index
        if feature_name not in features:
            raise ValueError(
                'Feature %s not found in features of model_metadata' %
                feature_name)
        feature_metadata = features[feature_name]

        # A missing, empty or numerical_identity encode type needs no lookup.
        encode_type = feature_metadata.get('encode_type')
        if not encode_type or encode_type == 'numerical_identity':
            continue

        if encode_type in required_lookup_attrs:
            attr_name = required_lookup_attrs[encode_type]
            if feature_index not in getattr(self, attr_name):
                raise ValueError('feature_index %d missing in %s' %
                                 (feature_index, attr_name))
        elif encode_type in optional_lookup_attrs:
            attr_name = optional_lookup_attrs[encode_type]
            lookup = getattr(self, attr_name)
            if lookup and feature_index not in lookup:
                raise ValueError('feature_index %d missing in %s' %
                                 (feature_index, attr_name))
        else:
            raise ValueError('Invalid encode_type %s for feature %s' %
                             (encode_type, feature_name))