public NaiveBayesModel fit()

in flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayes.java [68:140]


    /**
     * Trains a Naive Bayes model from a single input table.
     *
     * <p>For every input row the label column is read as a {@link Number}; it must be non-null
     * and its {@code intValue()} must equal its {@code doubleValue()}. The label is paired with
     * the vector read from the features column, and the resulting stream is run through {@code
     * ExtractFeatureFunction}, {@code GenerateFeatureWeightMapFunction}, {@code
     * AggregateIntoArrayFunction} and {@code GenerateModelFunction} in that order. The last stage
     * is pinned to parallelism 1.
     *
     * @param inputs the training data; exactly one table is expected, which is enforced through
     *     {@code Preconditions.checkArgument}
     * @return a {@link NaiveBayesModel} whose model data is the table created from the pipeline
     *     output, after being passed to {@code ParamUtils.updateExistingParams} together with
     *     this estimator's {@code paramMap}
     */
    public NaiveBayesModel fit(Table... inputs) {
        Preconditions.checkArgument(inputs.length == 1);

        final Table trainingTable = inputs[0];
        final String featuresColumn = getFeaturesCol();
        final String labelColumn = getLabelCol();
        final double smoothingFactor = getSmoothing();

        StreamTableEnvironment tableEnv =
                (StreamTableEnvironment) ((TableImpl) trainingTable).getTableEnvironment();

        // Schema of the resulting model-data table; independent of the stream pipeline below.
        Schema modelSchema =
                Schema.newBuilder()
                        .column(
                                "theta",
                                DataTypes.ARRAY(
                                        DataTypes.ARRAY(
                                                DataTypes.MAP(
                                                        DataTypes.DOUBLE(), DataTypes.DOUBLE()))))
                        .column("piArray", DataTypes.of(DenseVectorTypeInfo.INSTANCE))
                        .column("labels", DataTypes.of(DenseVectorTypeInfo.INSTANCE))
                        .build();

        // Step 1: turn each row into a (features, label) pair, validating the label on the way.
        MapFunction<Row, Tuple2<Vector, Double>> toLabeledVector =
                row -> {
                    Number label = (Number) row.getField(labelColumn);
                    Preconditions.checkNotNull(label, "Input data should contain label value.");
                    // Only labels without a fractional part that fit in an int are accepted.
                    Preconditions.checkArgument(
                            label.intValue() == label.doubleValue(),
                            "Label value should be indexed number.");
                    Vector features = (Vector) row.getField(featuresColumn);
                    return new Tuple2<>(features, label.doubleValue());
                };
        DataStream<Row> rows = tableEnv.toDataStream(trainingTable);
        DataStream<Tuple2<Vector, Double>> labeledVectors =
                rows.map(toLabeledVector, Types.TUPLE(VectorTypeInfo.INSTANCE, Types.DOUBLE));

        // Step 2: expand every labeled vector into per-feature tuples.
        DataStream<Tuple3<Double, Integer, Double>> featureTuples =
                labeledVectors.flatMap(new ExtractFeatureFunction());

        // Step 3: partition by the hash of the first two tuple fields and build weight maps.
        DataStream<Tuple4<Double, Integer, Map<Double, Double>, Integer>> weightMaps =
                DataStreamUtils.mapPartition(
                        featureTuples.keyBy(t -> new Tuple2<>(t.f0, t.f1).hashCode()),
                        new GenerateFeatureWeightMapFunction(),
                        Types.TUPLE(
                                Types.DOUBLE,
                                Types.INT,
                                Types.MAP(Types.DOUBLE, Types.DOUBLE),
                                Types.INT));

        // Step 4: partition by the first tuple field and collect the maps into arrays.
        DataStream<Tuple3<Double, Integer, Map<Double, Double>[]>> mapArrays =
                DataStreamUtils.mapPartition(
                        weightMaps.keyBy(t -> t.f0),
                        new AggregateIntoArrayFunction(),
                        Types.TUPLE(
                                Types.DOUBLE,
                                Types.INT,
                                Types.OBJECT_ARRAY(Types.MAP(Types.DOUBLE, Types.DOUBLE))));

        // Step 5: produce the model data in a single subtask.
        DataStream<NaiveBayesModelData> modelStream =
                DataStreamUtils.mapPartition(
                        mapArrays,
                        new GenerateModelFunction(smoothingFactor),
                        NaiveBayesModelData.TYPE_INFO);
        modelStream.getTransformation().setParallelism(1);

        Table modelTable = tableEnv.fromDataStream(modelStream, modelSchema);
        NaiveBayesModel fitted = new NaiveBayesModel().setModelData(modelTable);
        ParamUtils.updateExistingParams(fitted, paramMap);
        return fitted;
    }