in inference/xgboost_predictor/predictor.py [0:0]
def from_path(cls, model_dir):
  """Creates an instance of Predictor using the given path.

  `model_dir` must contain the XGBoost model file `model.bst` and an `assets`
  subdirectory with `model_metadata.json`. Each `*.txt` file in `assets` is
  recognized by its exact base name, `<i>_<kind>.txt`, where `<i>` is the
  integer feature index used as the key in the matching dictionary:

    * `<i>_categorical_one_hot.txt`, `<i>_categorical_label.txt` (legacy
      name: `<i>.txt`) and `<i>_array_one_hot.txt` (legacy name:
      `<i>_array.txt`): one vocabulary token per line.
    * `<i>_categorical_target.txt`, `<i>_array_target.txt`: one
      `token,value[,value...]` row per line, parsed into
      `{token: [float, ...]}`.
    * `<i>_array_struct_dimension.txt`, `<i>_array_numerical_length.txt`: a
      single integer.

  Text files whose base name matches none of these patterns are ignored.
  Files are processed in sorted order, so if both a legacy and a current
  file exist for the same index, the current-style file wins
  (`0.txt` sorts before `0_categorical_label.txt`).

  Args:
    model_dir: The local directory that contains the trained XGBoost model and
      the assets including vocabularies and model metadata.

  Returns:
    An instance of 'Predictor'.

  Raises:
    ValueError: If a target-encoding, array struct dimension or array
      numerical length file is malformed.
    OSError: If `model_metadata.json` or an asset file cannot be opened.
  """

  def read_lines(path):
    """Returns the lines of the UTF-8 text file at `path`, without newlines."""
    # Iterating a text-mode file splits only on \n, \r\n and \r. Unlike
    # str.splitlines(), it does not also split on \x0b, \x0c, \x1c-\x1e,
    # \x85, \u2028 or \u2029. Those characters can occur inside a vocabulary
    # token, and splitting on them would shift every later vocabulary index.
    with open(path, encoding='utf-8') as f:
      return [line.rstrip('\n') for line in f]

  def read_target_encoding(path):
    """Parses `token,value[,value...]` rows into {token: [float, ...]}."""
    target_dict = {}
    for line in read_lines(path):
      token, *values = line.split(',')
      try:
        target_dict[token] = [float(value) for value in values]
      except ValueError as err:
        raise ValueError(
            '%s does not have the right format for target encoding' % path
        ) from err
    return target_dict

  def read_int(path, description):
    """Parses the whole file at `path` as one (whitespace-padded) integer."""
    with open(path, encoding='utf-8') as f:
      contents = f.read()
    try:
      return int(contents.strip())
    except ValueError as err:
      raise ValueError(
          '%s does not have the right format for %s' % (path, description)
      ) from err

  # Keep model name the same as ml::kXgboostFinalModelFilename.
  model_path = os.path.join(model_dir, 'model.bst')
  model = xgb.Booster(model_file=model_path)
  assets_path = os.path.join(model_dir, 'assets')
  model_metadata_path = os.path.join(assets_path, 'model_metadata.json')
  with open(model_metadata_path, encoding='utf-8') as f:
    model_metadata = json.load(f)

  categorical_one_hot_vocab = {}
  categorical_target_vocab = {}
  categorical_label_vocab = {}
  array_one_hot_vocab = {}
  array_target_vocab = {}
  array_struct_dimension_dict = {}
  array_numerical_length_dict = {}

  # Each entry: (pattern the base name must match exactly, destination dict
  # keyed by feature index, parser for the file contents).
  asset_handlers = (
      (r'(\d+)_categorical_one_hot\.txt', categorical_one_hot_vocab,
       read_lines),
      (r'(\d+)_categorical_target\.txt', categorical_target_vocab,
       read_target_encoding),
      (r'(\d+)_categorical_label\.txt', categorical_label_vocab, read_lines),
      # Legacy file name for categorical label vocabularies.
      (r'(\d+)\.txt', categorical_label_vocab, read_lines),
      (r'(\d+)_array_one_hot\.txt', array_one_hot_vocab, read_lines),
      # Legacy file name for array one-hot vocabularies.
      (r'(\d+)_array\.txt', array_one_hot_vocab, read_lines),
      (r'(\d+)_array_target\.txt', array_target_vocab, read_target_encoding),
      (r'(\d+)_array_struct_dimension\.txt', array_struct_dimension_dict,
       lambda path: read_int(path, 'array struct dimension')),
      (r'(\d+)_array_numerical_length\.txt', array_numerical_length_dict,
       lambda path: read_int(path, 'array numerical length')),
  )

  # Escape the directory so glob metacharacters in model_dir ([, *, ?) are
  # taken literally. Sort the results so collisions resolve deterministically.
  txt_pattern = os.path.join(glob.escape(assets_path), '*.txt')
  for txt_file in sorted(glob.glob(txt_pattern)):
    # Match only the base name, so digits or "txt" in directory names cannot
    # be mistaken for a feature index.
    file_name = os.path.basename(txt_file)
    for pattern, destination, parse in asset_handlers:
      found = re.fullmatch(pattern, file_name)
      if found:
        destination[int(found.group(1))] = parse(txt_file)
        break

  return cls(
      model,
      model_metadata,
      categorical_one_hot_vocab,
      categorical_target_vocab,
      categorical_label_vocab,
      array_one_hot_vocab,
      array_target_vocab,
      array_struct_dimension_dict,
      array_numerical_length_dict,
  )