in flink-ml-lib/src/main/java/org/apache/flink/ml/feature/kbinsdiscretizer/KBinsDiscretizer.java [87:165]
public KBinsDiscretizerModel fit(Table... inputs) {
    // Wires up the computation that turns the single input table into bin edges and returns a
    // KBinsDiscretizerModel backed by that result, carrying over this estimator's parameters.
    // Fails the precondition check unless exactly one table is supplied.
    Preconditions.checkArgument(inputs.length == 1);
    StreamTableEnvironment tableEnv =
            (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();

    // These locals are what the lambda and the partition function below reference.
    int numBins = getNumBins();
    String strategy = getStrategy();
    String inputCol = getInputCol();

    // Pull the input column out of every row and convert it to a dense vector.
    DataStream<DenseVector> vectors =
            tableEnv.toDataStream(inputs[0])
                    .map(
                            (MapFunction<Row, DenseVector>)
                                    row -> ((Vector) row.getField(inputCol)).toDense());

    DataStream<DenseVector> candidates;
    if (strategy.equals(UNIFORM)) {
        // Two chained MinMaxReduceFunctionOperator stages; only the second one is pinned to
        // parallelism 1.
        // NOTE(review): presumably the first stage reduces inside each parallel partition and
        // the second merges those partial results - confirm against the operator.
        candidates =
                vectors.transform(
                                "reduceInEachPartition",
                                vectors.getType(),
                                new MinMaxReduceFunctionOperator())
                        .transform(
                                "reduceInFinalPartition",
                                vectors.getType(),
                                new MinMaxReduceFunctionOperator())
                        .setParallelism(1);
    } else {
        // Every other strategy works on the output of DataStreamUtils.sample. The last
        // argument is derived from the class name, so it is fixed for a given class.
        // NOTE(review): presumably (sample size, random seed) - confirm against the helper.
        candidates =
                DataStreamUtils.sample(
                        vectors, getSubSamples(), getClass().getName().hashCode());
    }

    // Buffers everything it receives, then computes the bin edges for the chosen strategy
    // and emits exactly one model data record per invocation.
    MapPartitionFunction<DenseVector, KBinsDiscretizerModelData> binEdgeFinder =
            new MapPartitionFunction<DenseVector, KBinsDiscretizerModelData>() {
                @Override
                public void mapPartition(
                        Iterable<DenseVector> iterable,
                        Collector<KBinsDiscretizerModelData> collector) {
                    List<DenseVector> buffered = new ArrayList<>();
                    for (DenseVector vector : iterable) {
                        buffered.add(vector);
                    }
                    if (buffered.isEmpty()) {
                        throw new RuntimeException("The training set is empty.");
                    }

                    double[][] binEdges;
                    switch (strategy) {
                        case UNIFORM:
                            binEdges = findBinEdgesWithUniformStrategy(buffered, numBins);
                            break;
                        case QUANTILE:
                            binEdges = findBinEdgesWithQuantileStrategy(buffered, numBins);
                            break;
                        case KMEANS:
                            binEdges = findBinEdgesWithKMeansStrategy(buffered, numBins);
                            break;
                        default:
                            throw new UnsupportedOperationException(
                                    "Unsupported " + STRATEGY + " type: " + strategy + ".");
                    }
                    collector.collect(new KBinsDiscretizerModelData(binEdges));
                }
            };

    DataStream<KBinsDiscretizerModelData> modelDataStream =
            DataStreamUtils.mapPartition(candidates, binEdgeFinder);
    // Pin the transformation behind the model data stream to parallelism 1.
    // NOTE(review): presumably so a single task sees all candidates and only one model data
    // record is produced - confirm.
    modelDataStream.getTransformation().setParallelism(1);

    // Expose the model data as a table, attach it to a fresh model and copy over the params.
    Table modelDataTable = tableEnv.fromDataStream(modelDataStream);
    KBinsDiscretizerModel model = new KBinsDiscretizerModel().setModelData(modelDataTable);
    ParamUtils.updateExistingParams(model, getParamMap());
    return model;
}