flink-ml-python/pyflink/ml/wrapper.py (283 lines of code) (raw):

################################################################################ # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information # regarding copyright ownership. The ASF licenses this file # to you under the Apache License, Version 2.0 (the # "License"); you may not use this file except in compliance # with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ################################################################################ import pickle from abc import ABC, abstractmethod from typing import List, Dict, Any from py4j.java_gateway import JavaObject, get_java_class from pyflink.common import typeinfo, Time, Row, RowKind from pyflink.common.typeinfo import _from_java_type, TypeInformation, _is_instance_of, Types, \ ExternalTypeInfo, RowTypeInfo, TupleTypeInfo from pyflink.datastream import utils from pyflink.datastream.utils import pickled_bytes_to_python_obj from pyflink.java_gateway import get_gateway from pyflink.table import Table, StreamTableEnvironment, Expression from pyflink.util.java_utils import to_jarray from pyflink.ml.api import Model, Transformer, AlgoOperator, Stage, Estimator from pyflink.ml.linalg import DenseVectorTypeInfo, SparseVectorTypeInfo, DenseMatrixTypeInfo, \ VectorTypeInfo, DenseVector from pyflink.ml.param import Param, WithParams, StringArrayParam, IntArrayParam, VectorParam, \ FloatArrayParam, FloatArrayArrayParam, WindowsParam from pyflink.ml.common.window import GlobalWindows, CountTumblingWindows, \ EventTimeTumblingWindows, ProcessingTimeTumblingWindows, EventTimeSessionWindows, \ ProcessingTimeSessionWindows _from_java_type_alias = _from_java_type def _from_java_type_wrapper(j_type_info: JavaObject) -> TypeInformation: gateway = get_gateway() JGenericTypeInfo = gateway.jvm.org.apache.flink.api.java.typeutils.GenericTypeInfo if _is_instance_of(j_type_info, JGenericTypeInfo): JClass = j_type_info.getTypeClass() if JClass == get_java_class(gateway.jvm.org.apache.flink.ml.linalg.DenseVector): return DenseVectorTypeInfo() elif JClass == get_java_class(gateway.jvm.org.apache.flink.ml.linalg.SparseVector): return SparseVectorTypeInfo() elif JClass == get_java_class(gateway.jvm.org.apache.flink.ml.linalg.DenseMatrix): return DenseMatrixTypeInfo() elif JClass == get_java_class(gateway.jvm.org.apache.flink.ml.linalg.Vector): return VectorTypeInfo() return _from_java_type_alias(j_type_info) typeinfo._from_java_type = _from_java_type_wrapper # TODO: Remove this class after Flink ML depends on a Flink version # with FLINK-30168 and FLINK-29477 fixed. def convert_to_python_obj_wrapper(data, type_info): if type_info == Types.PICKLED_BYTE_ARRAY(): return pickle.loads(data) elif isinstance(type_info, ExternalTypeInfo): return convert_to_python_obj_wrapper(data, type_info._type_info) else: gateway = get_gateway() pickle_bytes = gateway.jvm.org.apache.flink.ml.python.PythonBridgeUtils. \ getPickledBytesFromJavaObject(data, type_info.get_java_type_info()) if isinstance(type_info, RowTypeInfo) or isinstance(type_info, TupleTypeInfo): field_data = zip(list(pickle_bytes[1:]), type_info.get_field_types()) fields = [] for data, field_type in field_data: if len(data) == 0: fields.append(None) else: fields.append(pickled_bytes_to_python_obj(data, field_type)) if isinstance(type_info, RowTypeInfo): return Row.of_kind(RowKind(int.from_bytes(pickle_bytes[0], 'little')), *fields) else: return tuple(fields) else: return pickled_bytes_to_python_obj(pickle_bytes, type_info) utils.convert_to_python_obj = convert_to_python_obj_wrapper class JavaWrapper(ABC): """ Wrapper class for a Java object. """ def __init__(self, java_obj): self._java_obj = java_obj class JavaWithParams(WithParams, JavaWrapper): """ Wrapper class for a Java WithParams. """ def __init__(self, java_params): super(JavaWithParams, self).__init__(java_params) def set(self, param: Param, value) -> WithParams: if type(param) in _map_java_param_converter: converter = _map_java_param_converter[type(param)] else: converter = default_converter java_param_name = snake_to_camel(param.name) set_method_name = ''.join(['set', java_param_name[0].upper(), java_param_name[1:]]) gateway = get_gateway() gateway.jvm.org.apache.flink.iteration.utils.ReflectionUtils.callMethod( self._java_obj, self._java_obj.getClass(), set_method_name, to_jarray(gateway.jvm.Object, [converter.to_java(value)]) ) return self def get(self, param: Param): if type(param) in _map_java_param_converter: converter = _map_java_param_converter[type(param)] else: converter = default_converter java_param_name = snake_to_camel(param.name) get_method_name = ''.join(['get', java_param_name[0].upper(), java_param_name[1:]]) gateway = get_gateway() result = gateway.jvm.org.apache.flink.iteration.utils.ReflectionUtils.callMethod( self._java_obj, self._java_obj.getClass(), get_method_name ) return converter.to_python(result) def get_param_map(self) -> Dict[Param, Any]: return self._java_obj.getParamMap() class JavaStage(Stage, JavaWithParams, ABC): """ Wrapper class for a Java Stage. """ def __init__(self, java_stage): super(JavaStage, self).__init__(java_stage) def save(self, path: str) -> None: self._java_obj.save(path) @classmethod def load(cls, t_env: StreamTableEnvironment, path: str): java_model = _to_java_reference(cls._java_stage_path()).load(t_env._j_tenv, path) instance = cls(java_model) return instance @classmethod @abstractmethod def _java_stage_path(cls) -> str: pass class JavaAlgoOperator(AlgoOperator, JavaStage, ABC): """ Wrapper class for a Java AlgoOperator. """ def __init__(self, java_algo_operator): if java_algo_operator is None: super(JavaAlgoOperator, self).__init__(_to_java_reference(self._java_stage_path())()) else: super(JavaAlgoOperator, self).__init__(java_algo_operator) def transform(self, *inputs: Table) -> List[Table]: results = self._java_obj.transform(_to_java_tables(*inputs)) return [Table(t, inputs[0]._t_env) for t in results] class JavaTransformer(Transformer, JavaAlgoOperator, ABC): """ Wrapper class for a Java Transformer. """ def __init__(self, java_transformer): super(JavaTransformer, self).__init__(java_transformer) class JavaModel(Model, JavaTransformer, ABC): """ Wrapper class for a Java Model. """ def __init__(self, java_model): super(JavaModel, self).__init__(java_model) self._t_env = None def set_model_data(self, *inputs: Table) -> Model: self._t_env = inputs[0]._t_env self._java_obj.setModelData(_to_java_tables(*inputs)) return self def get_model_data(self) -> List[Table]: return [Table(t, self._t_env) for t in self._java_obj.getModelData()] class JavaEstimator(Estimator, JavaStage, ABC): """ Wrapper class for a Java Estimator. """ def __init__(self): super(JavaEstimator, self).__init__(_new_java_obj(self._java_stage_path())) def fit(self, *inputs: Table) -> Model: return self._create_model(self._java_obj.fit(_to_java_tables(*inputs))) @classmethod def _create_model(cls, java_model) -> Model: """ Creates a model from the input Java model reference. """ pass @classmethod def load(cls, t_env: StreamTableEnvironment, path: str): """ Instantiates a new stage instance based on the data read from the given path. """ java_estimator = _to_java_reference(cls._java_stage_path()).load(t_env._j_tenv, path) instance = cls() instance._java_obj = java_estimator return instance class JavaParamConverter(ABC): @abstractmethod def to_java(self, value): pass @abstractmethod def to_python(self, value): pass class DefaultJavaParamConverter(JavaParamConverter): def to_java(self, value): return value def to_python(self, value): return value class IntArrayJavaPramConverter(JavaParamConverter): def to_java(self, value): return to_jarray(get_gateway().jvm.java.lang.Integer, value) def to_python(self, value): return tuple(value[i] for i in range(len(value))) class FloatArrayJavaPramConverter(JavaParamConverter): def to_java(self, value): return to_jarray(get_gateway().jvm.java.lang.Double, value) def to_python(self, value): return tuple(value[i] for i in range(len(value))) class VectorJavaParamConverter(JavaParamConverter): def to_java(self, value): jarray = to_jarray(get_gateway().jvm.double, value.to_array()) return get_gateway().jvm.org.apache.flink.ml.linalg.DenseVector(jarray) def to_python(self, value): return DenseVector(tuple(value.get(i) for i in range(value.size()))) class WindowsJavaParamConverter(JavaParamConverter): @staticmethod def _to_java_time(time: Time): return get_gateway().jvm.org.apache.flink.api.common.time.Time.milliseconds( time.to_milliseconds()) @staticmethod def _to_python_time(time) -> Time: return Time.milliseconds(time.toMilliseconds()) def to_java(self, value): java_window_package = get_gateway().jvm.org.apache.flink.ml.common.window if isinstance(value, GlobalWindows): return java_window_package.GlobalWindows.getInstance() elif isinstance(value, CountTumblingWindows): return java_window_package.CountTumblingWindows.of(value.size) elif isinstance(value, EventTimeTumblingWindows): return java_window_package.EventTimeTumblingWindows.of( WindowsJavaParamConverter._to_java_time(value.size)) elif isinstance(value, ProcessingTimeTumblingWindows): return java_window_package.ProcessingTimeTumblingWindows.of( WindowsJavaParamConverter._to_java_time(value.size)) elif isinstance(value, EventTimeSessionWindows): return java_window_package.EventTimeSessionWindows.withGap( WindowsJavaParamConverter._to_java_time(value.gap)) elif isinstance(value, ProcessingTimeSessionWindows): return java_window_package.ProcessingTimeSessionWindows.withGap( WindowsJavaParamConverter._to_java_time(value.gap)) else: raise TypeError(f'Python object {str(value)}\' cannot be converted to Java object') def to_python(self, value): if value.getClass().getName() == \ "org.apache.flink.ml.common.window.GlobalWindows": return GlobalWindows() elif value.getClass().getName() == \ "org.apache.flink.ml.common.window.CountTumblingWindows": return CountTumblingWindows.of(value.getSize()) elif value.getClass().getName() == \ "org.apache.flink.ml.common.window.EventTimeTumblingWindows": return EventTimeTumblingWindows.of( WindowsJavaParamConverter._to_python_time(value.getSize())) elif value.getClass().getName() == \ "org.apache.flink.ml.common.window.ProcessingTimeTumblingWindows": return ProcessingTimeTumblingWindows.of( WindowsJavaParamConverter._to_python_time(value.getSize())) elif value.getClass().getName() == \ "org.apache.flink.ml.common.window.EventTimeSessionWindows": return EventTimeSessionWindows.with_gap( WindowsJavaParamConverter._to_python_time(value.getGap())) elif value.getClass().getName() == \ "org.apache.flink.ml.common.window.ProcessingTimeSessionWindows": return ProcessingTimeSessionWindows.with_gap( WindowsJavaParamConverter._to_python_time(value.getGap())) else: raise TypeError(f'Java object {str(value)}\' cannot be converted to Python object') class StringArrayJavaParamConverter(JavaParamConverter): def to_java(self, value): return to_jarray(get_gateway().jvm.java.lang.String, value) def to_python(self, value): return tuple(value[i] for i in range(len(value))) class FloatArrayArrayJavaPramConverter(JavaParamConverter): def to_java(self, value): n = len(value) m = len(value[0]) j_arr = get_gateway().new_array(get_gateway().jvm.java.lang.Double, n, m) for i in range(n): for j in range(m): j_arr[i][j] = value[i][j] return j_arr def to_python(self, value): n = len(value) m = len(value[0]) arr = [] for i in range(n): l = [] for j in range(m): l.append(value[i][j]) arr.append(tuple(l)) return tuple(arr) default_converter = DefaultJavaParamConverter() _map_java_param_converter = { IntArrayParam: IntArrayJavaPramConverter(), FloatArrayParam: FloatArrayJavaPramConverter(), FloatArrayArrayParam: FloatArrayArrayJavaPramConverter(), StringArrayParam: StringArrayJavaParamConverter(), VectorParam: VectorJavaParamConverter(), WindowsParam: WindowsJavaParamConverter(), Param: default_converter } def snake_to_camel(method_name): output = ''.join(x.capitalize() or '_' for x in method_name.split('_')) return output[0].lower() + output[1:] def _to_java_reference(java_class: str): java_obj = get_gateway().jvm for name in java_class.split("."): java_obj = getattr(java_obj, name) return java_obj def _new_java_obj(java_class: str, *java_args): """ Returns a new Java object. """ java_obj = _to_java_reference(java_class) return java_obj(*java_args) def _to_java_tables(*inputs: Table): """ Converts Python Tables to Java tables. """ gateway = get_gateway() return to_jarray(gateway.jvm.org.apache.flink.table.api.Table, [t._j_table for t in inputs]) def call_java_table_function(java_table_function_name: str, *args): _function = get_gateway().jvm for member_name in java_table_function_name.split('.'): _function = _function.__getattr__(member_name) return Expression(_function(to_jarray( get_gateway().jvm.java.lang.Object, [expression._j_expr for expression in args])))