in src/neo_loader/sklearn_model_loader.py [0:0]
def load_model(self) -> None:
    """Load the Scikit-Learn model artifact and convert it to a Relay module.

    Behaviour depends on ``self.transform_func``:

    * ``"transform"`` -- requires ``self.data_shape["input"]`` to have exactly
      two dimensions with a static (non ``-1``) last dimension. Then, for every
      ``ColumnTransformer`` step in ``model.feature_transformer.steps``, it:

      - removes the ``"drop"`` entries from ``transformers_`` (in place);
      - builds a numeric mapping from the ``categories_`` of any first pipeline
        step that is a ``ThresholdOneHotEncoder`` or ``RobustOrdinalEncoder``;
      - records the columns of the ``"datetime_processing"`` entry in
        ``self.date_col``.

      Finally it updates the categorical mapping with the declared number of
      input columns.
    * ``"inverse_transform"`` -- builds an inverse label mapping when
      ``model.target_transformer`` is a ``RobustLabelEncoder``.

    The model is then converted with ``relay.frontend.from_auto_ml``. A ``-1``
    in either input dimension becomes ``relay.Any()``. The result is passed
    through ``self.dynamic_to_static`` and stored in
    ``self._relay_module_object`` / ``self._params``.

    Raises:
        RuntimeError: if the input shape is invalid for ``"transform"``, or if
            conversion fails with anything other than ``OpError``. In the
            latter case the original exception is chained as ``__cause__``.
        OpError: re-raised unchanged from the conversion step.
    """
    model = self.__get_sklearn_model_from_model_artifacts()

    if self.transform_func == "transform":
        input_shape = self.data_shape["input"]
        if len(input_shape) != 2:
            raise RuntimeError(
                "InputConfiguration: InputShape for Sklearn model must have two dimensions, "
                "but got {}.".format(len(input_shape))
            )
        if input_shape[-1] == -1:
            raise RuntimeError(
                "InputConfiguration: InputShape for Sklearn model must have a static value for "
                "the second dimension, equal to the number of input columns or features."
            )

        for _, transformer in model.feature_transformer.steps:
            # Type names are compared as strings, so no sklearn / autopilot classes need importing here.
            if type(transformer).__name__ != "ColumnTransformer":
                continue
            # Every (name, pipeline, cols) entry that is not "drop". This list replaces
            # transformers_ below, so "drop" entries never reach the Relay conversion.
            kept_transformers = []
            for name, pipeline, cols in transformer.transformers_:
                if pipeline == "drop":
                    continue
                kept_transformers.append((name, pipeline, cols))
                # NOTE(review): any other string entry (e.g. "passthrough") has no `.steps`
                # and raises AttributeError here, outside the try below -- confirm whether
                # such entries can occur in supported models.
                mod = pipeline.steps[0][1]
                if type(mod).__name__ in ("ThresholdOneHotEncoder", "RobustOrdinalEncoder"):
                    self.__build_numeric_mapping(mod.categories_, cols)
                if name == "datetime_processing":
                    self.date_col = cols
            transformer.transformers_ = kept_transformers
        self.__update_categorical_mapping(self.data_shape["input"][-1])
    elif self.transform_func == "inverse_transform":
        if type(model.target_transformer).__name__ == 'RobustLabelEncoder':
            self.__build_inverse_label_mapping(model.target_transformer)

    try:
        input_shape = self.data_shape["input"]
        # -1 marks a dynamic dimension; Relay represents it with Any().
        num_rows = input_shape[0] if input_shape[0] != -1 else relay.Any()
        num_cols = input_shape[-1] if input_shape[-1] != -1 else relay.Any()
        self._relay_module_object, self._params = relay.frontend.from_auto_ml(
            model, (num_rows, num_cols), FLOAT_32, self.transform_func
        )
        self._relay_module_object = self.dynamic_to_static(self._relay_module_object)
        self.update_missing_metadata()
    except OpError:
        # OpError propagates unchanged instead of being wrapped in RuntimeError.
        raise
    except Exception as e:
        logger.exception("Failed to convert Scikit-Learn model. %s", repr(e))
        msg = "InputConfiguration: TVM cannot convert Scikit-Learn model. Please make sure the framework you selected is correct. {}".format(e)
        # Chain the cause so the original traceback is preserved for debugging.
        raise RuntimeError(msg) from e