sdks/python/apache_beam/ml/anomaly/detectors/pyod_adapter.py (58 lines of code) (raw):

# # Licensed to the Apache Software Foundation (ASF) under one or more # contributor license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright ownership. # The ASF licenses this file to You under the Apache License, Version 2.0 # (the "License"); you may not use this file except in compliance with # the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # import pickle from collections.abc import Iterable from collections.abc import Sequence from typing import Any from typing import Optional import numpy as np import apache_beam as beam from apache_beam.io.filesystems import FileSystems from apache_beam.ml.anomaly.detectors.offline import OfflineDetector from apache_beam.ml.anomaly.specifiable import specifiable from apache_beam.ml.anomaly.thresholds import FixedThreshold from apache_beam.ml.inference.base import KeyedModelHandler from apache_beam.ml.inference.base import ModelHandler from apache_beam.ml.inference.base import PredictionResult from apache_beam.ml.inference.base import _PostProcessingModelHandler from apache_beam.ml.inference.utils import _convert_to_result from pyod.models.base import BaseDetector as PyODBaseDetector # Turn the used ModelHandler into specifiable, but without lazy init. 
# Turn the ModelHandler classes used below into specifiable classes, but
# without lazy init (both on-demand and just-in-time initialization are
# disabled). The module-level names are rebound to the wrapped classes.
KeyedModelHandler = specifiable(  # type: ignore[misc]
    KeyedModelHandler,
    on_demand_init=False,
    just_in_time_init=False)
_PostProcessingModelHandler = specifiable(  # type: ignore[misc]
    _PostProcessingModelHandler,
    on_demand_init=False,
    just_in_time_init=False)


@specifiable
class PyODModelHandler(ModelHandler[beam.Row,
                                    PredictionResult,
                                    PyODBaseDetector]):
  """Implementation of the ModelHandler interface for PyOD [#]_ Models.

  The ModelHandler processes input data as `beam.Row` objects. Every field of
  a row is converted to a float64 feature, so all rows in a batch must have
  the same number of numeric fields.

  **NOTE:** This API and its implementation are currently under active
  development and may not be backward compatible.

  Args:
    model_uri: The URI specifying the location of the pickled PyOD model.

  .. [#] https://github.com/yzhao062/pyod
  """
  def __init__(self, model_uri: str):
    self._model_uri = model_uri

  def load_model(self) -> PyODBaseDetector:
    """Loads the pickled PyOD model stored at ``model_uri``.

    **Security:** ``pickle.load`` can execute arbitrary code while
    deserializing. Only point ``model_uri`` at trusted model files.

    Returns:
      The unpickled PyOD detector.
    """
    # Use a context manager so the handle returned by FileSystems.open is
    # closed once unpickling finishes instead of leaking until GC.
    with FileSystems.open(self._model_uri, 'rb') as model_file:
      return pickle.load(model_file)

  def run_inference(
      self,
      batch: Sequence[beam.Row],
      model: PyODBaseDetector,
      inference_args: Optional[dict[str, Any]] = None
  ) -> Iterable[PredictionResult]:
    """Scores a batch of rows with ``model.decision_function``.

    Args:
      batch: Rows to score. Each row is converted to a 1-D float64 array, so
        every field must be numeric and all rows must have the same length.
      model: The PyOD detector returned by :meth:`load_model`.
      inference_args: Unused; accepted for ModelHandler interface
        compatibility.

    Returns:
      The PredictionResults built by ``_convert_to_result`` from the batch and
      the per-row scores, tagged with ``model_uri`` as the model id. An empty
      batch yields an empty list.
    """
    if not batch:
      # np.stack raises ValueError on an empty sequence; there is nothing to
      # score, so short-circuit.
      return []

    # Stack the per-row feature vectors into a single 2-D array of shape
    # (len(batch), num_fields) so the model scores the whole batch in one
    # vectorized call.
    vectorized_batch = np.stack(
        [np.fromiter(row, dtype=np.float64) for row in batch], axis=0)
    predictions = model.decision_function(vectorized_batch)

    return _convert_to_result(batch, predictions, model_id=self._model_uri)


class PyODFactory:
  """Factory helpers for building anomaly detectors backed by PyOD models."""
  @staticmethod
  def create_detector(model_uri: str, **kwargs) -> OfflineDetector:
    """A utility function to create OfflineDetector for a PyOD model.

    The model is loaded once, eagerly, when this function is called so that
    its ``threshold_`` attribute can seed a :class:`FixedThreshold` criterion.

    **NOTE:** This API and its implementation are currently under active
    development and may not be backward compatible.

    Args:
      model_uri: The URI specifying the location of the pickled PyOD model.
      **kwargs: Additional keyword arguments forwarded to ``OfflineDetector``.

    Returns:
      An OfflineDetector wrapping a ``KeyedModelHandler`` over
      :class:`PyODModelHandler` (post-processed with
      ``OfflineDetector.score_prediction_adapter``), with
      ``threshold_criterion=FixedThreshold(model.threshold_)``.

    Raises:
      TypeError: If the object unpickled from ``model_uri`` is not a PyOD
        ``BaseDetector``.
    """
    model_handler = KeyedModelHandler(
        PyODModelHandler(model_uri=model_uri)).with_postprocess_fn(
            OfflineDetector.score_prediction_adapter)
    model = model_handler.load_model()
    # Raise explicitly instead of using `assert`, which is stripped when
    # Python runs with -O and would let a non-PyOD object through.
    if not isinstance(model, PyODBaseDetector):
      raise TypeError(
          f"Expected a pickled PyOD BaseDetector at {model_uri!r}, "
          f"got {type(model).__name__}.")
    # PyOD sets `threshold_` when a detector is fitted; presumably the pickled
    # model must already be fitted (otherwise this raises AttributeError).
    threshold = float(model.threshold_)
    detector = OfflineDetector(
        model_handler, threshold_criterion=FixedThreshold(threshold), **kwargs)  # type: ignore[arg-type]
    return detector