public LogisticRegressionModel fit()

in flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java [60:123]


    /**
     * Trains a binary logistic regression model from a single input table.
     *
     * <p>Each row of the table is converted into a {@code LabeledPointWithWeight} using the
     * configured features, label and weight columns. When no weight column is configured, every
     * row is given a weight of {@code 1.0}. The resulting stream is optimized with {@code SGD}
     * against {@code BinaryLogisticLoss}, starting from a {@code DenseVector} constructed from the
     * (common) feature dimension of the training data.
     *
     * <p>The call is rejected by precondition checks when {@code inputs} does not contain exactly
     * one table, or when the configured multi-class option is neither {@code "auto"} nor {@code
     * "binomial"}. Rows whose label is not {@code 0} or {@code 1} cause a {@link RuntimeException}
     * inside the map function, i.e. when the stream is processed rather than when this method is
     * called.
     *
     * @param inputs the training data; must contain exactly one table
     * @return a model whose model data table is built from the optimized coefficient vector and
     *     which has been updated with this estimator's {@code paramMap}
     */
    public LogisticRegressionModel fit(Table... inputs) {
        Preconditions.checkArgument(inputs.length == 1);
        String classificationType = getMultiClass();
        Preconditions.checkArgument(
                "auto".equals(classificationType) || "binomial".equals(classificationType),
                "Multinomial classification is not supported yet. Supported options: [auto, binomial].");
        StreamTableEnvironment tEnv =
                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();

        // Resolve the column names once, outside the per-record function: the lambda below then
        // captures only these three strings instead of the enclosing estimator, and no parameter
        // lookup is repeated for every row.
        final String featuresCol = getFeaturesCol();
        final String labelCol = getLabelCol();
        final String weightCol = getWeightCol();

        DataStream<LabeledPointWithWeight> trainData =
                tEnv.toDataStream(inputs[0])
                        .map(
                                dataPoint -> {
                                    // A missing weight column means "unweighted": use 1.0.
                                    double weight =
                                            weightCol == null
                                                    ? 1.0
                                                    : ((Number) dataPoint.getField(weightCol))
                                                            .doubleValue();
                                    double label =
                                            ((Number) dataPoint.getField(labelCol)).doubleValue();
                                    // Only the labels 0 and 1 are accepted.
                                    boolean isBinomial =
                                            Double.compare(0., label) == 0
                                                    || Double.compare(1., label) == 0;
                                    if (!isBinomial) {
                                        throw new RuntimeException(
                                                "Multinomial classification is not supported yet. Supported options: [auto, binomial].");
                                    }
                                    DenseVector features =
                                            ((Vector) dataPoint.getField(featuresCol)).toDense();
                                    return new LabeledPointWithWeight(features, label, weight);
                                });

        // Reduce the per-record feature dimensions to one value, verifying on the way that every
        // pair of dimensions being combined is equal, then build the initial model vector from it.
        DataStream<DenseVector> initModelData =
                DataStreamUtils.reduce(
                                trainData.map(x -> x.getFeatures().size()),
                                (ReduceFunction<Integer>)
                                        (t0, t1) -> {
                                            Preconditions.checkState(
                                                    t0.equals(t1),
                                                    "The training data should all have same dimensions.");
                                            return t0;
                                        })
                        .map(DenseVector::new);

        Optimizer optimizer =
                new SGD(
                        getMaxIter(),
                        getLearningRate(),
                        getGlobalBatchSize(),
                        getTol(),
                        getReg(),
                        getElasticNet());
        DataStream<DenseVector> rawModelData =
                optimizer.optimize(initModelData, trainData, BinaryLogisticLoss.INSTANCE);

        // NOTE(review): the meaning of the second constructor argument (0L) is not visible here;
        // presumably a model version - confirm against LogisticRegressionModelData.
        DataStream<LogisticRegressionModelData> modelData =
                rawModelData.map(vector -> new LogisticRegressionModelData(vector, 0L));
        LogisticRegressionModel model =
                new LogisticRegressionModel().setModelData(tEnv.fromDataStream(modelData));
        ParamUtils.updateExistingParams(model, paramMap);
        return model;
    }