in flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegression.java [86:138]
public OnlineLogisticRegressionModel fit(Table... inputs) {
    /*
     * Builds the online training pipeline and returns a model backed by the resulting
     * model-data stream:
     *   1. read the initial model data from initModelDataTable,
     *   2. project the single input table to (features, label[, weight]) rows,
     *   3. feed both streams into an unbounded iteration driven by FtrlIterationBody,
     *   4. wrap the first output stream of the iteration as the model data table of a new
     *      OnlineLogisticRegressionModel that carries this estimator's parameters.
     */

    // Exactly one training table is accepted.
    Preconditions.checkArgument(inputs.length == 1);

    StreamTableEnvironment tEnv =
            (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
    DataStream<LogisticRegressionModelData> modelDataStream =
            LogisticRegressionModelDataUtil.getModelDataStream(initModelDataTable);

    final String featuresCol = getFeaturesCol();
    final String labelCol = getLabelCol();
    // May be null, in which case the extracted rows carry no weight field.
    final String weightCol = getWeightCol();

    // Row type of the extracted points; the field types are taken from the input schema.
    RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
    final TypeInformation<Row> pointTypeInfo;
    if (weightCol == null) {
        pointTypeInfo =
                Types.ROW(
                        inputTypeInfo.getTypeAt(featuresCol),
                        inputTypeInfo.getTypeAt(labelCol));
    } else {
        pointTypeInfo =
                Types.ROW(
                        inputTypeInfo.getTypeAt(featuresCol),
                        inputTypeInfo.getTypeAt(labelCol),
                        inputTypeInfo.getTypeAt(weightCol));
    }

    DataStream<Row> points =
            tEnv.toDataStream(inputs[0])
                    .map(
                            new FeaturesLabelExtractor(featuresCol, labelCol, weightCol),
                            pointTypeInfo);

    // Only the coefficient vector of the initial model data enters the iteration.
    DataStream<DenseVector> initModelData =
            modelDataStream.map(
                    (MapFunction<LogisticRegressionModelData, DenseVector>)
                            value -> value.coefficient);
    // The coefficient-extraction step is pinned to a single parallel instance.
    initModelData.getTransformation().setParallelism(1);

    IterationBody body =
            new FtrlIterationBody(
                    getGlobalBatchSize(), getAlpha(), getBeta(), getReg(), getElasticNet());

    // The first output stream of the iteration is used as the model data stream.
    DataStream<LogisticRegressionModelData> onlineModelData =
            Iterations.iterateUnboundedStreams(
                            DataStreamList.of(initModelData), DataStreamList.of(points), body)
                    .get(0);

    Table onlineModelDataTable = tEnv.fromDataStream(onlineModelData);
    OnlineLogisticRegressionModel model =
            new OnlineLogisticRegressionModel().setModelData(onlineModelDataTable);
    // Pass this estimator's paramMap to the model via ParamUtils.updateExistingParams.
    ParamUtils.updateExistingParams(model, paramMap);
    return model;
}