public LinearRegressionModel fit(Table... inputs)

in flink-ml-lib/src/main/java/org/apache/flink/ml/regression/linearregression/LinearRegression.java [60:111]


    public LinearRegressionModel fit(Table... inputs) {
        // Trains a linear regression model on the single input table: each row is converted to
        // a (features, label, weight) point, the coefficients start from a DenseVector sized to
        // the feature dimension, and SGD minimizes LeastSquareLoss over the training points.
        // The returned model carries the optimized coefficients and this estimator's params.
        Preconditions.checkArgument(
                inputs.length == 1,
                "LinearRegression expects exactly one input table, but got %s.",
                inputs.length);
        StreamTableEnvironment tEnv =
                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();

        // Resolve the column names once here. Calling the getters inside the lambda below would
        // capture `this` (the whole estimator and its param map) in the closure that Flink
        // serializes, and would repeat the param lookups for every record.
        final String weightCol = getWeightCol();
        final String labelCol = getLabelCol();
        final String featuresCol = getFeaturesCol();

        DataStream<LabeledPointWithWeight> trainData =
                tEnv.toDataStream(inputs[0])
                        .map(
                                dataPoint -> {
                                    // Every row weighs 1.0 when no weight column is configured.
                                    double weight =
                                            weightCol == null
                                                    ? 1.0
                                                    : ((Number) dataPoint.getField(weightCol))
                                                            .doubleValue();
                                    double label =
                                            ((Number) dataPoint.getField(labelCol)).doubleValue();
                                    DenseVector features =
                                            ((Vector) dataPoint.getField(featuresCol)).toDense();
                                    return new LabeledPointWithWeight(features, label, weight);
                                });

        // Reduce all feature sizes to a single dimension (failing the job if any two records
        // disagree), then build the initial coefficient vector of that size.
        DataStream<DenseVector> initModelData =
                DataStreamUtils.reduce(
                                trainData.map(x -> x.getFeatures().size()),
                                (ReduceFunction<Integer>)
                                        (t0, t1) -> {
                                            Preconditions.checkState(
                                                    t0.equals(t1),
                                                    "The training data should all have same dimensions.");
                                            return t0;
                                        })
                        .map(DenseVector::new);

        Optimizer optimizer =
                new SGD(
                        getMaxIter(),
                        getLearningRate(),
                        getGlobalBatchSize(),
                        getTol(),
                        getReg(),
                        getElasticNet());
        DataStream<DenseVector> rawModelData =
                optimizer.optimize(initModelData, trainData, LeastSquareLoss.INSTANCE);

        DataStream<LinearRegressionModelData> modelData =
                rawModelData.map(LinearRegressionModelData::new);
        LinearRegressionModel model =
                new LinearRegressionModel().setModelData(tEnv.fromDataStream(modelData));
        // Propagate this estimator's param values to the model's matching params.
        ParamUtils.updateExistingParams(model, paramMap);
        return model;
    }