in flink-ml-lib/src/main/java/org/apache/flink/ml/feature/imputer/Imputer.java [63:112]
public ImputerModel fit(Table... inputs) {
    // Trains an ImputerModel from exactly one input table: the rows are aggregated with the
    // aggregator matching the configured strategy, and the resulting model data is exposed to
    // the model as a table with a single "surrogates" column.
    //
    // Fails (via Preconditions.checkArgument) when the number of input tables is not one or
    // when the input and output column counts differ; throws RuntimeException for a strategy
    // that is none of MEAN, MEDIAN or MOST_FREQUENT.
    Preconditions.checkArgument(inputs.length == 1);
    final String[] inputCols = getInputCols();
    Preconditions.checkArgument(
            inputCols.length == getOutputCols().length,
            "Num of input columns and output columns are inconsistent.");

    final Table trainTable = inputs[0];
    final StreamTableEnvironment tableEnv =
            (StreamTableEnvironment) ((TableImpl) trainTable).getTableEnvironment();
    final DataStream<Row> rows = tableEnv.toDataStream(trainTable);

    // Each strategy supplies its own aggregator and accumulator type information; the output
    // type is ImputerModelData.TYPE_INFO in every case.
    final DataStream<ImputerModelData> surrogateStream;
    switch (getStrategy()) {
        case MEAN:
            // NOTE(review): the (DOUBLE, LONG) tuple per column is presumably a running
            // (sum, count) pair; confirm against MeanStrategyAggregator.
            MeanStrategyAggregator meanAggregator =
                    new MeanStrategyAggregator(inputCols, getMissingValue());
            surrogateStream =
                    DataStreamUtils.aggregate(
                            rows,
                            meanAggregator,
                            Types.MAP(Types.STRING, Types.TUPLE(Types.DOUBLE, Types.LONG)),
                            ImputerModelData.TYPE_INFO);
            break;
        case MEDIAN:
            // The median aggregator is the only one that also receives the relative error.
            MedianStrategyAggregator medianAggregator =
                    new MedianStrategyAggregator(
                            inputCols, getMissingValue(), getRelativeError());
            surrogateStream =
                    DataStreamUtils.aggregate(
                            rows,
                            medianAggregator,
                            Types.MAP(Types.STRING, TypeInformation.of(QuantileSummary.class)),
                            ImputerModelData.TYPE_INFO);
            break;
        case MOST_FREQUENT:
            // NOTE(review): the inner DOUBLE -> LONG map presumably counts occurrences of
            // each value per column; confirm against MostFrequentStrategyAggregator.
            MostFrequentStrategyAggregator mostFrequentAggregator =
                    new MostFrequentStrategyAggregator(inputCols, getMissingValue());
            surrogateStream =
                    DataStreamUtils.aggregate(
                            rows,
                            mostFrequentAggregator,
                            Types.MAP(Types.STRING, Types.MAP(Types.DOUBLE, Types.LONG)),
                            ImputerModelData.TYPE_INFO);
            break;
        default:
            throw new RuntimeException("Unsupported strategy of Imputer: " + getStrategy());
    }

    // Declare the model-data table explicitly as one MAP<STRING, DOUBLE> column.
    final Schema surrogateSchema =
            Schema.newBuilder()
                    .column("surrogates", DataTypes.MAP(DataTypes.STRING(), DataTypes.DOUBLE()))
                    .build();
    final Table modelDataTable = tableEnv.fromDataStream(surrogateStream, surrogateSchema);

    final ImputerModel fittedModel = new ImputerModel().setModelData(modelDataTable);
    // Hand this estimator's param map to the model through ParamUtils.
    ParamUtils.updateExistingParams(fittedModel, getParamMap());
    return fittedModel;
}