in flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java [60:123]
/**
 * Trains a binary logistic regression model on the single input table.
 *
 * <p>Steps performed, in order:
 *
 * <ol>
 *   <li>Validate that exactly one table is given and that the multi-class parameter is either
 *       {@code "auto"} or {@code "binomial"}.
 *   <li>Convert every row into a {@link LabeledPointWithWeight}. The weight is {@code 1.0} when
 *       no weight column is configured; the label must be exactly {@code 0} or {@code 1}.
 *   <li>Build the initial model by reducing the feature dimensions of all samples to a single
 *       value (failing if they differ) and constructing a {@link DenseVector} from it.
 *   <li>Run {@link SGD} with {@link BinaryLogisticLoss} and wrap the resulting coefficient
 *       vector into {@link LogisticRegressionModelData}.
 * </ol>
 *
 * @param inputs the training tables; exactly one table is expected.
 * @return a {@link LogisticRegressionModel} carrying the trained model data and the parameters
 *     of this estimator that also exist on the model.
 * @throws IllegalArgumentException if the number of inputs is not one or the configured
 *     multi-class type is neither {@code "auto"} nor {@code "binomial"}.
 */
public LogisticRegressionModel fit(Table... inputs) {
    Preconditions.checkArgument(inputs.length == 1);
    String classificationType = getMultiClass();
    Preconditions.checkArgument(
            "auto".equals(classificationType) || "binomial".equals(classificationType),
            "Multinomial classification is not supported yet. Supported options: [auto, binomial].");
    StreamTableEnvironment tEnv =
            (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();

    // Resolve the column names once, up front. Calling the getters inside the map function
    // would repeat the parameter lookup for every record and would make the lambda capture
    // this estimator instead of three plain strings.
    final String weightCol = getWeightCol();
    final String labelCol = getLabelCol();
    final String featuresCol = getFeaturesCol();

    DataStream<LabeledPointWithWeight> trainData =
            tEnv.toDataStream(inputs[0])
                    .map(
                            dataPoint -> {
                                // Samples are unweighted unless a weight column is set.
                                double weight =
                                        weightCol == null
                                                ? 1.0
                                                : ((Number) dataPoint.getField(weightCol))
                                                        .doubleValue();
                                double label =
                                        ((Number) dataPoint.getField(labelCol)).doubleValue();
                                // Only the labels 0 and 1 are accepted; the check runs per
                                // record, so a bad label surfaces when the job executes.
                                boolean isBinomial =
                                        Double.compare(0., label) == 0
                                                || Double.compare(1., label) == 0;
                                if (!isBinomial) {
                                    throw new RuntimeException(
                                            "Multinomial classification is not supported yet. Supported options: [auto, binomial].");
                                }
                                DenseVector features =
                                        ((Vector) dataPoint.getField(featuresCol)).toDense();
                                return new LabeledPointWithWeight(features, label, weight);
                            });

    // Reduce all feature dimensions to a single value, verifying on the way that every
    // sample has the same dimension, then create the initial coefficient vector from it.
    DataStream<DenseVector> initModelData =
            DataStreamUtils.reduce(
                            trainData.map(x -> x.getFeatures().size()),
                            (ReduceFunction<Integer>)
                                    (t0, t1) -> {
                                        Preconditions.checkState(
                                                t0.equals(t1),
                                                "The training data should all have same dimensions.");
                                        return t0;
                                    })
                    .map(DenseVector::new);

    Optimizer optimizer =
            new SGD(
                    getMaxIter(),
                    getLearningRate(),
                    getGlobalBatchSize(),
                    getTol(),
                    getReg(),
                    getElasticNet());
    DataStream<DenseVector> rawModelData =
            optimizer.optimize(initModelData, trainData, BinaryLogisticLoss.INSTANCE);

    // NOTE(review): the second constructor argument (0L) is presumably the initial model
    // version; confirm against LogisticRegressionModelData.
    DataStream<LogisticRegressionModelData> modelData =
            rawModelData.map(vector -> new LogisticRegressionModelData(vector, 0L));
    LogisticRegressionModel model =
            new LogisticRegressionModel().setModelData(tEnv.fromDataStream(modelData));
    // Copy this estimator's parameter values onto the model.
    ParamUtils.updateExistingParams(model, paramMap);
    return model;
}