in flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorindexer/VectorIndexer.java [91:139]
public VectorIndexerModel fit(Table... inputs) {
    // Fitting runs in three stages:
    //   1. each parallel subtask gathers distinct values from the input column
    //      (ComputeDistinctDoublesOperator, configured with the input column and maxCategories),
    //      producing an array of lists -- presumably one list per vector dimension;
    //   2. the partial arrays are reduced into a single array, entry by entry;
    //   3. a single-parallelism ModelGenerator turns the merged array into model data.
    Preconditions.checkArgument(inputs.length == 1);

    final int maxCategories = getMaxCategories();
    final StreamTableEnvironment tableEnv =
            (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();

    // Stage 1: local distinct values on every subtask.
    final DataStream<List<Double>[]> partialDistinctValues =
            tableEnv.toDataStream(inputs[0])
                    .transform(
                            "computeDistinctDoublesOperator",
                            Types.OBJECT_ARRAY(Types.LIST(Types.DOUBLE)),
                            new ComputeDistinctDoublesOperator(getInputCol(), maxCategories));

    // Stage 2: merge two partial arrays into the left-hand one (mutated and returned).
    // If either side holds null for an entry, the merged entry becomes null -- presumably the
    // marker for "too many distinct values" (TODO confirm against ComputeDistinctDoublesOperator).
    // Otherwise the two lists are unioned without duplicates.
    final ReduceFunction<List<Double>[]> mergeEntryWise =
            (left, right) -> {
                for (int col = 0; col < left.length; col++) {
                    if (left[col] == null || right[col] == null) {
                        left[col] = null;
                        continue;
                    }
                    final HashSet<Double> union = new HashSet<>(left[col]);
                    union.addAll(right[col]);
                    left[col] = new ArrayList<>(union);
                }
                return left;
            };
    final DataStream<List<Double>[]> mergedDistinctValues =
            DataStreamUtils.reduce(partialDistinctValues, mergeEntryWise);

    // Stage 3: build the model data on a single task.
    final DataStream<VectorIndexerModelData> modelDataStream =
            mergedDistinctValues.map(
                    new ModelGenerator(maxCategories), VectorIndexerModelData.TYPE_INFO);
    modelDataStream.getTransformation().setParallelism(1);

    // The model data table has one column, "categoryMaps", of type MAP<INT, MAP<DOUBLE, INT>>.
    final Schema modelSchema =
            Schema.newBuilder()
                    .column(
                            "categoryMaps",
                            DataTypes.MAP(
                                    DataTypes.INT(),
                                    DataTypes.MAP(DataTypes.DOUBLE(), DataTypes.INT())))
                    .build();
    final Table modelDataTable = tableEnv.fromDataStream(modelDataStream, modelSchema);

    final VectorIndexerModel fittedModel = new VectorIndexerModel().setModelData(modelDataTable);
    // Propagate this estimator's paramMap onto the fitted model.
    ParamUtils.updateExistingParams(fittedModel, paramMap);
    return fittedModel;
}