public LinearSVCModel fit()

in flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linearsvc/LinearSVC.java [58:113]


    /**
     * Trains a linear SVC model from a single input table.
     *
     * <p>Every input row is turned into a {@link LabeledPointWithWeight} built from the configured
     * features, label and (optional) weight columns. The starting coefficient vector is a {@link
     * DenseVector} constructed from the feature dimension shared by all rows, and the coefficients
     * are then fitted by {@link SGD} with {@link HingeLoss}.
     *
     * <p>Label validation (only 0.0 and 1.0 are accepted) and the check that all feature vectors
     * have the same size are performed inside the stream functions, i.e. per record, not by this
     * method call itself.
     *
     * @param inputs the training data; exactly one table is expected.
     * @return a {@link LinearSVCModel} whose model data is the table derived from the optimizer
     *     output, updated with this estimator's parameter map.
     * @throws IllegalArgumentException if the number of input tables is not exactly one.
     */
    public LinearSVCModel fit(Table... inputs) {
        Preconditions.checkArgument(
                inputs.length == 1,
                "LinearSVC expects exactly one input table, but got %s.",
                inputs.length);
        StreamTableEnvironment tEnv =
                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();

        // Resolve the column names once. The per-record function below then only captures these
        // strings instead of calling the getters for every row and holding on to this estimator.
        final String featuresCol = getFeaturesCol();
        final String labelCol = getLabelCol();
        final String weightCol = getWeightCol();

        DataStream<LabeledPointWithWeight> trainData =
                tEnv.toDataStream(inputs[0])
                        .map(
                                dataPoint -> {
                                    // Without a weight column every sample gets weight 1.0.
                                    double weight =
                                            weightCol == null
                                                    ? 1.0
                                                    : ((Number) dataPoint.getField(weightCol))
                                                            .doubleValue();
                                    double label =
                                            ((Number) dataPoint.getField(labelCol)).doubleValue();
                                    // Only the two labels 0.0 and 1.0 are accepted.
                                    Preconditions.checkState(
                                            Double.compare(0.0, label) == 0
                                                    || Double.compare(1.0, label) == 0,
                                            "LinearSVC only supports binary classification. But detected label: %s.",
                                            label);
                                    DenseVector features =
                                            ((Vector) dataPoint.getField(featuresCol)).toDense();
                                    return new LabeledPointWithWeight(features, label, weight);
                                });

        // Reduce the per-row feature sizes to a single value, failing if two rows disagree, and
        // build the initial model vector from that size.
        DataStream<DenseVector> initModelData =
                DataStreamUtils.reduce(
                                trainData.map(x -> x.getFeatures().size()),
                                (ReduceFunction<Integer>)
                                        (t0, t1) -> {
                                            Preconditions.checkState(
                                                    t0.equals(t1),
                                                    "The training data should all have same dimensions.");
                                            return t0;
                                        })
                        .map(DenseVector::new);

        Optimizer optimizer =
                new SGD(
                        getMaxIter(),
                        getLearningRate(),
                        getGlobalBatchSize(),
                        getTol(),
                        getReg(),
                        getElasticNet());
        DataStream<DenseVector> rawModelData =
                optimizer.optimize(initModelData, trainData, HingeLoss.INSTANCE);

        // Wrap the optimized coefficients as model data and hand them to the model as a table.
        DataStream<LinearSVCModelData> modelData = rawModelData.map(LinearSVCModelData::new);
        LinearSVCModel model = new LinearSVCModel().setModelData(tEnv.fromDataStream(modelData));
        ParamUtils.updateExistingParams(model, paramMap);
        return model;
    }