in flink-ml-lib/src/main/java/org/apache/flink/ml/feature/stringindexer/StringIndexer.java [98:152]
public StringIndexerModel fit(Table... inputs) {
    // Fits a StringIndexerModel from exactly one input table:
    //   1. validate the arguments and parameter combination,
    //   2. run CountStringOperator over the input rows to obtain, per input column, a
    //      string -> count map (presumably one partial result per parallel subtask; confirm
    //      against CountStringOperator),
    //   3. merge all partial results by summing the counts of equal strings,
    //   4. turn the merged counts into model data with ModelGenerator (parallelism 1).
    Preconditions.checkArgument(inputs.length == 1);
    final Table input = inputs[0];

    final String[] inputCols = getInputCols();
    final String[] outputCols = getOutputCols();
    Preconditions.checkArgument(inputCols.length == outputCols.length);

    // A bounded max index number is only accepted together with frequency-descending order.
    if (getMaxIndexNum() < Integer.MAX_VALUE) {
        final String unsupportedComboMsg =
                "Setting "
                        + MAX_INDEX_NUM.name
                        + " smaller than INT.MAX only works when "
                        + STRING_ORDER_TYPE.name
                        + " is set as "
                        + StringIndexerParams.FREQUENCY_DESC_ORDER
                        + ".";
        Preconditions.checkArgument(
                getStringOrderType().equals(StringIndexerParams.FREQUENCY_DESC_ORDER),
                unsupportedComboMsg);
    }

    StreamTableEnvironment tEnv =
            (StreamTableEnvironment) ((TableImpl) input).getTableEnvironment();

    // Partial counts: element i of each emitted array belongs to inputCols[i]
    // (assumption based on the operator being constructed with inputCols).
    DataStream<Map<String, Long>[]> partialCounts =
            tEnv.toDataStream(input)
                    .transform(
                            "countStringOperator",
                            Types.OBJECT_ARRAY(Types.MAP(Types.STRING, Types.LONG)),
                            new CountStringOperator(inputCols));

    // Folds the right-hand counts into the left-hand maps in place, column by column, and
    // returns the (mutated) left-hand array as the accumulated result.
    ReduceFunction<Map<String, Long>[]> sumCounts =
            (accumulated, incoming) -> {
                for (int col = 0; col < accumulated.length; col++) {
                    final Map<String, Long> target = accumulated[col];
                    incoming[col].forEach((str, cnt) -> target.merge(str, cnt, Long::sum));
                }
                return accumulated;
            };

    DataStream<Map<String, Long>[]> totalCounts =
            DataStreamUtils.reduce(
                    partialCounts,
                    sumCounts,
                    Types.OBJECT_ARRAY(Types.MAP(Types.STRING, Types.LONG)));

    DataStream<StringIndexerModelData> modelData =
            totalCounts.map(new ModelGenerator(getStringOrderType(), getMaxIndexNum()));
    // The model data is generated by a single, non-parallel task.
    modelData.getTransformation().setParallelism(1);

    Table modelDataTable = tEnv.fromDataStream(modelData);
    StringIndexerModel model = new StringIndexerModel().setModelData(modelDataTable);
    // Carry this estimator's parameter values over to the model via ParamUtils.
    ParamUtils.updateExistingParams(model, paramMap);
    return model;
}