flink-ml-python/pyflink/ml/feature/lsh.py (90 lines of code) (raw):
################################################################################
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
import typing
from abc import ABC
from pyflink.java_gateway import get_gateway
from pyflink.table import Table
from pyflink.util.java_utils import to_jarray
from pyflink.ml.linalg import Vector, DenseVector, SparseVector
from pyflink.ml.param import Param, IntParam, ParamValidators
from pyflink.ml.wrapper import JavaWithParams
from pyflink.ml.feature.common import JavaFeatureEstimator, JavaFeatureModel
from pyflink.ml.common.param import HasInputCol, HasOutputCol, HasSeed
class _LSHModelParams(JavaWithParams,
HasInputCol,
HasOutputCol):
"""
Params for :class:`LSHModel`
"""
def __init__(self, java_params):
super(_LSHModelParams, self).__init__(java_params)
class _LSHParams(_LSHModelParams):
"""
Params for :class:`LSH`
"""
"""
Param for the number of hash tables used in LSH OR-amplification.
OR-amplification can be used to reduce the false negative rate. Higher values of this
param lead to a reduced false negative rate, at the expense of added computational
complexity.
"""
NUM_HASH_TABLES: Param[int] = IntParam(
"num_hash_tables", "Number of hash tables.", 1, ParamValidators.gt_eq(1)
)
"""
Param for the number of hash functions per hash table used in LSH AND-amplification.
AND-amplification can be used to reduce the false positive rate. Higher values of this
param lead to a reduced false positive rate, at the expense of added computational
complexity.
"""
NUM_HASH_FUNCTIONS_PER_TABLE: Param[int] = IntParam(
"num_hash_functions_per_table",
"Number of hash functions per table.",
1,
ParamValidators.gt_eq(1))
def __init__(self, java_params):
super(_LSHParams, self).__init__(java_params)
def set_num_hash_tables(self, value: int):
return typing.cast(_LSHParams, self.set(self.NUM_HASH_TABLES, value))
def get_num_hash_tables(self):
return self.get(self.NUM_HASH_TABLES)
@property
def num_hash_tables(self):
return self.get_num_hash_tables()
def set_num_hash_functions_per_table(self, value: int):
return typing.cast(_LSHParams, self.set(self.NUM_HASH_FUNCTIONS_PER_TABLE, value))
def get_num_hash_functions_per_table(self):
return self.get(self.NUM_HASH_FUNCTIONS_PER_TABLE)
@property
def num_hash_functions_per_table(self):
return self.get_num_hash_functions_per_table()
class _MinHashLSHParams(_LSHParams, HasSeed):
"""
Params for :class:`MinHashLSH`
"""
def __init__(self, java_params):
super(_MinHashLSHParams, self).__init__(java_params)
class _LSH(JavaFeatureEstimator, ABC):
"""
Base class for estimators that support LSH (Locality-sensitive hashing) algorithm for different
metrics (e.g., Jaccard distance).
The basic idea of LSH is to use a set of hash functions to map input vectors into different
buckets, where closer vectors are expected to be in the same bucket with higher probabilities.
In detail, each input vector is hashed by all functions. To decide whether two input vectors
are mapped into the same bucket, two mechanisms for assigning buckets are proposed as follows.
<ul>
<li>AND-amplification: The two input vectors are defined to be in the same bucket as long as
ALL of the hash value matches.
<li>OR-amplification: The two input vectors are defined to be in the same bucket as long as
ANY of the hash value matches.
</ul>
See: <a href="https://en.wikipedia.org/wiki/Locality-sensitive_hashing">
Locality-sensitive_hashing</a>.
"""
def __init__(self):
super(_LSH, self).__init__()
@classmethod
def _java_estimator_package_name(cls) -> str:
return "lsh"
class _LSHModel(JavaFeatureModel, ABC):
"""
Base class for LSH model.
In addition to transforming input feature vectors to multiple hash values, it also supports
approximate nearest neighbors search within a dataset regarding a key vector and approximate
similarity join between two datasets.
"""
def __init__(self, java_model):
super(_LSHModel, self).__init__(java_model)
@classmethod
def _java_model_package_name(cls) -> str:
return "lsh"
def approx_nearest_neighbors(self, dataset: Table, key: Vector, k: int,
dist_col: str = 'distCol'):
"""
Approximately finds at most k items from a dataset which have the closest distance to a
given item. If the `outputCol` is missing in the given dataset, this method transforms the
dataset with the model at first.
:param dataset: The dataset in which to to search for nearest neighbors.
:param key: The item to search for.
:param k: The maximum number of nearest neighbors.
:param dist_col: The output column storing the distance between each neighbor and the key.
:return: A dataset containing at most k items closest to the key with a column named
`distCol` appended.
"""
j_vectors = get_gateway().jvm.org.apache.flink.ml.linalg.Vectors
if isinstance(key, (DenseVector,)):
j_key = j_vectors.dense(to_jarray(get_gateway().jvm.double, key.values.tolist()))
elif isinstance(key, (SparseVector,)):
# noinspection PyProtectedMember
j_key = j_vectors.sparse(
key.size(),
to_jarray(get_gateway().jvm.int, key._indices.tolist()),
to_jarray(get_gateway().jvm.double, key._values.tolist())
)
else:
raise TypeError(f'Key {key} must be an instance of Vector.')
# noinspection PyProtectedMember
return Table(self._java_obj.approxNearestNeighbors(
dataset._j_table, j_key, k, dist_col), self._t_env)
def approx_similarity_join(self, dataset_a: Table, dataset_b: Table, threshold: float,
id_col: str, dist_col: str = 'distCol'):
"""
Joins two datasets to approximately find all pairs of rows whose distance are smaller than
or equal to the threshold. If the `outputCol` is missing in either dataset, this method
transforms the dataset at first.
:param dataset_a: One dataset.
:param dataset_b: The other dataset.
:param threshold: The distance threshold.
:param id_col: A column in the two datasets to identify each row.
:param dist_col: The output column storing the distance between each pair of rows.
:return: A joined dataset containing pairs of rows. The original rows are in columns
"dataset_a" and "dataset_b", and a column "distCol" is added to show the distance
between each pair.
"""
# noinspection PyProtectedMember
return Table(self._java_obj.approxSimilarityJoin(
dataset_a._j_table, dataset_b._j_table,
threshold, id_col, dist_col), self._t_env)
class MinHashLSHModel(_LSHModel, _LSHModelParams):
"""
A Model which generates hash values using the model data computed by :class:`MinHashLSH`.
"""
def __init__(self, java_model=None):
super(MinHashLSHModel, self).__init__(java_model)
@classmethod
def _java_model_class_name(cls) -> str:
return "MinHashLSHModel"
class MinHashLSH(_LSH, _MinHashLSHParams):
"""
An Estimator that implements the MinHash LSH algorithm, which supports LSH for Jaccard distance.
The input could be dense or sparse vectors. Each input vector must have at least one non-zero
index and all non-zero values are treated as binary "1" values. The sizes of input vectors
should be same and not larger than a predefined prime (i.e., 2038074743).
See: <a href="https://en.wikipedia.org/wiki/MinHash">MinHash</a>.
"""
def __init__(self):
super(MinHashLSH, self).__init__()
@classmethod
def _create_model(cls, java_model) -> MinHashLSHModel:
return MinHashLSHModel(java_model)
@classmethod
def _java_estimator_class_name(cls) -> str:
return "MinHashLSH"