in python/treelite/gallery/sklearn/__init__.py [0:0]
def import_model(sklearn_model):
"""
Load a tree ensemble model from a scikit-learn model object
Parameters
----------
sklearn_model : object of type \
:py:class:`~sklearn.ensemble.RandomForestRegressor` / \
:py:class:`~sklearn.ensemble.RandomForestClassifier` / \
:py:class:`~sklearn.ensemble.GradientBoostingRegressor` / \
:py:class:`~sklearn.ensemble.GradientBoostingClassifier`
Python handle to scikit-learn model
Returns
-------
model : :py:class:`~treelite.Model` object
loaded model
Example
-------
.. code-block:: python
:emphasize-lines: 8
import sklearn.datasets
import sklearn.ensemble
X, y = sklearn.datasets.load_boston(return_X_y=True)
clf = sklearn.ensemble.RandomForestRegressor(n_estimators=10)
clf.fit(X, y)
import treelite.gallery.sklearn
model = treelite.gallery.sklearn.import_model(clf)
"""
class_name = sklearn_model.__class__.__name__
module_name = sklearn_model.__module__.split('.')[0]
if module_name != 'sklearn':
raise Exception('Not a scikit-learn model')
_execfile('common.py')
if class_name == 'RandomForestRegressor':
_execfile('rf_regressor.py')
elif class_name == 'RandomForestClassifier':
if sklearn_model.n_classes_ == 2:
_execfile('rf_classifier.py')
elif sklearn_model.n_classes_ > 2:
_execfile('rf_multi_classifier.py')
else:
raise Exception('n_classes_ must be at least 2')
elif class_name == 'GradientBoostingRegressor':
_execfile('gbm_regressor.py')
elif class_name == 'GradientBoostingClassifier':
if sklearn_model.n_classes_ == 2:
_execfile('gbm_classifier.py')
elif sklearn_model.n_classes_ > 2:
_execfile('gbm_multi_classifier.py')
else:
raise Exception('n_classes_ must be at least 2')
else:
raise Exception('Unsupported model type: only '
'random forests and gradient boosted trees are supported')
return process_model(sklearn_model)