public OnlineLogisticRegressionModel fit()

in flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegression.java [86:138]


    /**
     * Trains an {@link OnlineLogisticRegressionModel} from the given input table.
     *
     * <p>This estimator's {@code initModelDataTable} is converted into a stream of coefficient
     * vectors that seeds an unbounded iteration. Each row of the input table is reduced to its
     * features and label columns, plus the weight column when {@code getWeightCol()} is non-null.
     * The reduced rows are fed to the iteration as its data stream. The model data stream
     * produced by the iteration becomes the model data of the returned model.
     *
     * @param inputs exactly one table containing the features column, the label column and,
     *     optionally, the weight column.
     * @return a model whose model data table wraps the model data stream produced by the
     *     iteration, with its parameters updated from this estimator's {@code paramMap} via
     *     {@code ParamUtils.updateExistingParams}.
     * @throws IllegalArgumentException if {@code inputs} does not contain exactly one table.
     * @throws NullPointerException if the initial model data table has not been set.
     */
    public OnlineLogisticRegressionModel fit(Table... inputs) {
        Preconditions.checkArgument(
                inputs.length == 1,
                "OnlineLogisticRegression expects exactly one input table, but got %s.",
                inputs.length);
        // Fail fast with a clear message instead of an anonymous NPE deeper in the pipeline.
        Preconditions.checkNotNull(
                initModelDataTable,
                "The initial model data table must be set before calling fit().");

        StreamTableEnvironment tEnv =
                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
        DataStream<LogisticRegressionModelData> modelDataStream =
                LogisticRegressionModelDataUtil.getModelDataStream(initModelDataTable);

        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
        TypeInformation<?> featuresTypeInfo = inputTypeInfo.getTypeAt(getFeaturesCol());
        TypeInformation<?> labelTypeInfo = inputTypeInfo.getTypeAt(getLabelCol());
        // NOTE(review): assumes FeaturesLabelExtractor emits fields in the order
        // features, label[, weight] -- keep this type info in sync with it.
        TypeInformation<Row> pointTypeInfo =
                getWeightCol() == null
                        ? Types.ROW(featuresTypeInfo, labelTypeInfo)
                        : Types.ROW(
                                featuresTypeInfo,
                                labelTypeInfo,
                                inputTypeInfo.getTypeAt(getWeightCol()));

        DataStream<Row> points =
                tEnv.toDataStream(inputs[0])
                        .map(
                                new FeaturesLabelExtractor(
                                        getFeaturesCol(), getLabelCol(), getWeightCol()),
                                pointTypeInfo);

        DataStream<DenseVector> initModelData =
                modelDataStream.map(
                        (MapFunction<LogisticRegressionModelData, DenseVector>)
                                value -> value.coefficient);

        // The initial coefficients are produced by a single parallel instance.
        initModelData.getTransformation().setParallelism(1);

        IterationBody body =
                new FtrlIterationBody(
                        getGlobalBatchSize(), getAlpha(), getBeta(), getReg(), getElasticNet());

        DataStream<LogisticRegressionModelData> onlineModelData =
                Iterations.iterateUnboundedStreams(
                                DataStreamList.of(initModelData), DataStreamList.of(points), body)
                        .get(0);

        Table onlineModelDataTable = tEnv.fromDataStream(onlineModelData);
        OnlineLogisticRegressionModel model =
                new OnlineLogisticRegressionModel().setModelData(onlineModelDataTable);
        ParamUtils.updateExistingParams(model, paramMap);
        return model;
    }