in flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linearsvc/LinearSVC.java [58:113]
public LinearSVCModel fit(Table... inputs) {
    // Trains a binary linear SVM on the single input table and returns a LinearSVCModel whose
    // model data wraps the coefficient vector produced by SGD with the hinge loss.
    //
    // inputs: exactly one table containing the features column (a Vector), the label column
    //         (a Number, must be 0.0 or 1.0) and, if configured, the weight column (a Number).
    // returns: a LinearSVCModel with its model data set and this estimator's params applied to
    //          it via ParamUtils.updateExistingParams.
    // Throws IllegalArgumentException if inputs does not hold exactly one table. Invalid labels
    // and mismatched feature dimensions are detected by Preconditions.checkState inside the
    // stream functions, i.e. while the job runs rather than when fit() is called.
    Preconditions.checkArgument(
            inputs.length == 1, "LinearSVC expects exactly one input table.");

    // Resolve column names once, here on the client. The per-record lambda below then captures
    // only these strings instead of the whole estimator, and it does not repeat the param
    // lookups for every record. Training also uses the params as they were when fit() was called.
    final String featuresCol = getFeaturesCol();
    final String labelCol = getLabelCol();
    final String weightCol = getWeightCol();

    StreamTableEnvironment tEnv =
            (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
    DataStream<LabeledPointWithWeight> trainData =
            tEnv.toDataStream(inputs[0])
                    .map(
                            dataPoint -> {
                                // Without a configured weight column every sample has weight 1.0.
                                double weight =
                                        weightCol == null
                                                ? 1.0
                                                : ((Number) dataPoint.getField(weightCol))
                                                        .doubleValue();
                                double label =
                                        ((Number) dataPoint.getField(labelCol)).doubleValue();
                                Preconditions.checkState(
                                        Double.compare(0.0, label) == 0
                                                || Double.compare(1.0, label) == 0,
                                        "LinearSVC only supports binary classification. But detected label: %s.",
                                        label);
                                DenseVector features =
                                        ((Vector) dataPoint.getField(featuresCol)).toDense();
                                return new LabeledPointWithWeight(features, label, weight);
                            });

    // Reduce the feature sizes to a single dimension and fail if any two rows disagree. Then
    // build the initial coefficient vector of that size (DenseVector(int); presumably
    // zero-initialized - confirm against DenseVector).
    DataStream<DenseVector> initModelData =
            DataStreamUtils.reduce(
                            trainData.map(x -> x.getFeatures().size()),
                            (ReduceFunction<Integer>)
                                    (t0, t1) -> {
                                        Preconditions.checkState(
                                                t0.equals(t1),
                                                "The training data should all have same dimensions.");
                                        return t0;
                                    })
                    .map(DenseVector::new);

    // Minimize the hinge loss with SGD, configured from this estimator's params.
    Optimizer optimizer =
            new SGD(
                    getMaxIter(),
                    getLearningRate(),
                    getGlobalBatchSize(),
                    getTol(),
                    getReg(),
                    getElasticNet());
    DataStream<DenseVector> rawModelData =
            optimizer.optimize(initModelData, trainData, HingeLoss.INSTANCE);

    // Wrap the raw coefficients as model data and hand the estimator's params to the model.
    DataStream<LinearSVCModelData> modelData = rawModelData.map(LinearSVCModelData::new);
    LinearSVCModel model = new LinearSVCModel().setModelData(tEnv.fromDataStream(modelData));
    ParamUtils.updateExistingParams(model, paramMap);
    return model;
}