in flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayes.java [68:140]
public NaiveBayesModel fit(Table... inputs) {
    // Builds the training pipeline for a Naive Bayes model from a single input table and
    // returns a NaiveBayesModel whose model data is the table produced by that pipeline.
    // Exactly one input table is accepted; any other count fails the argument check below.
    Preconditions.checkArgument(inputs.length == 1);

    // Resolve the parameters into locals so the lambdas below only reference these values.
    final String featuresCol = getFeaturesCol();
    final String labelCol = getLabelCol();
    final double smoothing = getSmoothing();

    StreamTableEnvironment tEnv =
            (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();

    // Convert every row into a (features, label) pair, validating both fields on the way.
    DataStream<Tuple2<Vector, Double>> input =
            tEnv.toDataStream(inputs[0])
                    .map(
                            (MapFunction<Row, Tuple2<Vector, Double>>)
                                    row -> {
                                        Number number = (Number) row.getField(labelCol);
                                        Preconditions.checkNotNull(
                                                number,
                                                "Input data should contain label value.");
                                        // Only labels with no fractional part are accepted.
                                        Preconditions.checkArgument(
                                                number.intValue() == number.doubleValue(),
                                                "Label value should be indexed number.");
                                        // Fail fast with a descriptive message instead of
                                        // letting a null vector travel inside the tuple.
                                        Vector features = (Vector) row.getField(featuresCol);
                                        Preconditions.checkNotNull(
                                                features,
                                                "Input data should contain features value.");
                                        return new Tuple2<>(features, number.doubleValue());
                                    },
                            Types.TUPLE(VectorTypeInfo.INSTANCE, Types.DOUBLE));

    // NOTE(review): presumably emits one (label, featureIndex, featureValue) record per
    // vector entry -- confirm against ExtractFeatureFunction.
    DataStream<Tuple3<Double, Integer, Double>> feature =
            input.flatMap(new ExtractFeatureFunction());

    // Records are keyed by the hash code of the (f0, f1) pair before the per-partition step.
    DataStream<Tuple4<Double, Integer, Map<Double, Double>, Integer>> featureWeight =
            DataStreamUtils.mapPartition(
                    feature.keyBy(value -> new Tuple2<>(value.f0, value.f1).hashCode()),
                    new GenerateFeatureWeightMapFunction(),
                    Types.TUPLE(
                            Types.DOUBLE,
                            Types.INT,
                            Types.MAP(Types.DOUBLE, Types.DOUBLE),
                            Types.INT));

    // Re-key by f0 alone so AggregateIntoArrayFunction sees records grouped by that field.
    DataStream<Tuple3<Double, Integer, Map<Double, Double>[]>> aggregatedArrays =
            DataStreamUtils.mapPartition(
                    featureWeight.keyBy(value -> value.f0),
                    new AggregateIntoArrayFunction(),
                    Types.TUPLE(
                            Types.DOUBLE,
                            Types.INT,
                            Types.OBJECT_ARRAY(Types.MAP(Types.DOUBLE, Types.DOUBLE))));

    // The model-generation step receives the smoothing parameter and is forced to run with
    // parallelism 1.
    DataStream<NaiveBayesModelData> modelData =
            DataStreamUtils.mapPartition(
                    aggregatedArrays,
                    new GenerateModelFunction(smoothing),
                    NaiveBayesModelData.TYPE_INFO);
    modelData.getTransformation().setParallelism(1);

    // Explicit schema for the model-data table: "theta", "piArray" and "labels".
    Schema schema =
            Schema.newBuilder()
                    .column(
                            "theta",
                            DataTypes.ARRAY(
                                    DataTypes.ARRAY(
                                            DataTypes.MAP(
                                                    DataTypes.DOUBLE(), DataTypes.DOUBLE()))))
                    .column("piArray", DataTypes.of(DenseVectorTypeInfo.INSTANCE))
                    .column("labels", DataTypes.of(DenseVectorTypeInfo.INSTANCE))
                    .build();

    NaiveBayesModel model =
            new NaiveBayesModel().setModelData(tEnv.fromDataStream(modelData, schema));
    // Carry this estimator's parameter values over to the parameters the model also defines.
    ParamUtils.updateExistingParams(model, paramMap);
    return model;
}