# docker_images/sklearn/app/pipelines/common.py

import json
import logging
import warnings
from abc import abstractmethod
from pathlib import Path
from typing import Any

import joblib
import skops.io as sio
from huggingface_hub import snapshot_download

from app.pipelines import Pipeline

logger = logging.getLogger(__name__)

# Model filename used when `config.json` does not specify one.
DEFAULT_FILENAME = "sklearn_model.joblib"


class SklearnBasePipeline(Pipeline):
    """Base class for sklearn-based inference pipelines

    Concrete implementations should add two methods:

    - `_get_output`: Method to generate model predictions

    - `__call__`: Should delegate to handle_call, add docstring and type
      annotations.
    """

    def __init__(self, model_id: str):
        """Download the model repo and load the sklearn model.

        Any warning or exception raised while loading is stored on the
        instance (``self._load_warnings`` / ``self._load_exception``) and
        surfaced later when ``__call__`` is invoked, so that construction
        itself never fails.

        Args:
            model_id: Hugging Face Hub repository id of the model.
        """
        cached_folder = snapshot_download(repo_id=model_id)
        self._load_warnings: list = []
        self._load_exception: Exception | None = None
        try:
            with open(Path(cached_folder) / "config.json") as f:
                # this is the default path for configuration of a scikit-learn
                # project. If the project is created using `skops`, it should
                # have this file.
                config = json.load(f)
        except Exception:
            config = dict()
            warnings.warn("`config.json` does not exist or is invalid.")

        self.model_file = (
            config.get("sklearn", {}).get("model", {}).get("file", DEFAULT_FILENAME)
        )
        self.model_format = config.get("sklearn", {}).get("model_format", "pickle")

        try:
            with warnings.catch_warnings(record=True) as record:
                if self.model_format == "pickle":
                    # open via a context manager so the file handle is always
                    # closed (the previous implementation leaked it)
                    with open(Path(cached_folder) / self.model_file, "rb") as fh:
                        self.model = joblib.load(fh)
                elif self.model_format == "skops":
                    self.model = sio.load(
                        file=Path(cached_folder) / self.model_file, trusted=True
                    )
                else:
                    # an unknown format previously left `self.model` unset,
                    # causing a confusing AttributeError at prediction time;
                    # raise here so it is reported via `_load_exception`.
                    raise ValueError(
                        f"Unsupported model format: {self.model_format}"
                    )
            if len(record) > 0:
                # if there's a warning while loading the model, we save it so
                # that it can be raised to the user when __call__ is called.
                self._load_warnings += record
        except Exception as e:
            # if there is an exception while loading the model, we save it to
            # raise the right error when __call__ is called.
            self._load_exception = e

        # use column names from the config file if available, to give the data
        # to the model in the right order.
        self.columns = config.get("sklearn", {}).get("columns", None)

    @abstractmethod
    def _get_output(self, inputs: Any) -> Any:
        """Return the model prediction for `inputs`; implemented by subclasses."""
        raise NotImplementedError(
            "Implement this method to get the model output (prediction)"
        )

    def __call__(self, inputs: Any) -> Any:
        """Handle call for getting the model prediction

        This method is responsible for handling all possible errors and
        warnings. To get the actual prediction, implement the `_get_output`
        method. The types of the inputs and output depend on the specific task
        being implemented.

        Raises:
            ValueError: if the model failed to load, or if any warnings were
                emitted (encoded as JSON so routes.py returns a non-200).
            Exception: re-raises whatever `_get_output` raised.
        """
        if self._load_exception:
            # there has been an error while loading the model. We need to raise
            # that, and can't call predict on the model.
            raise ValueError(
                "An error occurred while loading the model: "
                f"{str(self._load_exception)}"
            )

        _warnings = []
        if self.columns:
            # TODO: we should probably warn if columns are not configured, we
            # really do need them.
            given_cols = set(inputs["data"].keys())
            expected = set(self.columns)
            extra = given_cols - expected
            missing = expected - given_cols
            if extra:
                _warnings.append(
                    f"The following columns were given but not expected: {extra}"
                )
            if missing:
                _warnings.append(
                    f"The following columns were expected but not given: {missing}"
                )

        exception = None
        try:
            with warnings.catch_warnings(record=True) as record:
                res = self._get_output(inputs)
        except Exception as e:
            exception = e

        # `record` is bound by the `with` statement before the body runs, so it
        # is available here even when `_get_output` raised.
        for warning in record:
            _warnings.append(f"{warning.category.__name__}({warning.message})")
        for warning in self._load_warnings:
            _warnings.append(f"{warning.category.__name__}({warning.message})")

        if _warnings:
            for warning in _warnings:
                logger.warning(warning)
            if not exception:
                # we raise an error if there are any warnings, so that
                # routes.py can catch and return a non 200 status code.
                error = {
                    "error": "There were warnings while running the model.",
                    "output": res,
                    "warnings": _warnings,
                    # see issue #96
                }
                raise ValueError(json.dumps(error))
            else:
                # if there was an exception, we raise it so that routes.py can
                # catch and return a non 200 status code.
                raise exception

        if exception:
            raise exception

        return res