public StringIndexerModel fit()

in flink-ml-lib/src/main/java/org/apache/flink/ml/feature/stringindexer/StringIndexer.java [98:152]


    public StringIndexerModel fit(Table... inputs) {
        // Fits a StringIndexerModel on exactly one input table. The steps are:
        //   1. count the occurrences of each string value per input column (CountStringOperator);
        //   2. merge the partial counts by summing them column by column (DataStreamUtils.reduce);
        //   3. convert the merged counts into StringIndexerModelData (ModelGenerator), using the
        //      configured string order type and max index number.
        // The returned model receives this estimator's param values.
        // Throws IllegalArgumentException when the number of inputs is not 1, when the input and
        // output column counts differ, or when maxIndexNum < Integer.MAX_VALUE is combined with a
        // string order type other than FREQUENCY_DESC_ORDER.
        Preconditions.checkArgument(
                inputs.length == 1, "Expected exactly one input table, but got %s.", inputs.length);
        String[] inputCols = getInputCols();
        String[] outputCols = getOutputCols();
        Preconditions.checkArgument(
                inputCols.length == outputCols.length,
                "The number of input columns (%s) must equal the number of output columns (%s).",
                inputCols.length,
                outputCols.length);
        if (getMaxIndexNum() < Integer.MAX_VALUE) {
            // The message below states that a finite maxIndexNum is only supported with the
            // frequency-descending order. Putting the constant first avoids an NPE if the
            // order-type param is unset.
            Preconditions.checkArgument(
                    StringIndexerParams.FREQUENCY_DESC_ORDER.equals(getStringOrderType()),
                    "Setting "
                            + MAX_INDEX_NUM.name
                            + " smaller than INT.MAX only works when "
                            + STRING_ORDER_TYPE.name
                            + " is set as "
                            + StringIndexerParams.FREQUENCY_DESC_ORDER
                            + ".");
        }
        StreamTableEnvironment tEnv =
                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();

        // Partial counts: one Map<String, Long> per input column, keyed by the string value.
        DataStream<Map<String, Long>[]> localCountedString =
                tEnv.toDataStream(inputs[0])
                        .transform(
                                "countStringOperator",
                                Types.OBJECT_ARRAY(Types.MAP(Types.STRING, Types.LONG)),
                                new CountStringOperator(inputCols));

        // Combine the partial counts column by column, adding the counts of equal strings.
        // value1 is updated in place and returned, so no new arrays or maps are allocated.
        DataStream<Map<String, Long>[]> countedString =
                DataStreamUtils.reduce(
                        localCountedString,
                        (ReduceFunction<Map<String, Long>[]>)
                                (value1, value2) -> {
                                    for (int i = 0; i < value1.length; i++) {
                                        for (Entry<String, Long> stringAndCnt :
                                                value2[i].entrySet()) {
                                            value1[i].merge(
                                                    stringAndCnt.getKey(),
                                                    stringAndCnt.getValue(),
                                                    Long::sum);
                                        }
                                    }
                                    return value1;
                                },
                        Types.OBJECT_ARRAY(Types.MAP(Types.STRING, Types.LONG)));

        // Run the model generator as a single task. The assumption (to be confirmed) is that the
        // reduced counts form one global result, so one task should turn them into model data.
        DataStream<StringIndexerModelData> modelData =
                countedString
                        .map(new ModelGenerator(getStringOrderType(), getMaxIndexNum()))
                        .setParallelism(1);

        StringIndexerModel model =
                new StringIndexerModel().setModelData(tEnv.fromDataStream(modelData));
        // Copy this estimator's param values onto the model.
        ParamUtils.updateExistingParams(model, paramMap);
        return model;
    }