in flink-ml-lib/src/main/java/org/apache/flink/ml/feature/univariatefeatureselector/UnivariateFeatureSelector.java [103:154]
public UnivariateFeatureSelectorModel fit(Table... inputs) {
    // Fits the selector on exactly one input table: run the univariate test that matches
    // the configured feature/label types, reduce its result to model data, and wrap that
    // model data in a new UnivariateFeatureSelectorModel.
    Preconditions.checkArgument(inputs.length == 1);
    final Table trainingTable = inputs[0];

    final String featuresColName = getFeaturesCol();
    final String labelColName = getLabelCol();
    final String featureKind = getFeatureType();
    final String labelKind = getLabelType();

    // The table environment is recovered from the input table itself.
    StreamTableEnvironment tableEnv =
            (StreamTableEnvironment) ((TableImpl) trainingTable).getTableEnvironment();

    Table testResult =
            runUnivariateTest(
                    trainingTable, featuresColName, labelColName, featureKind, labelKind);

    // The selection operator is pinned to parallelism 1, so a single operator instance
    // receives every row of the test result.
    DataStream<UnivariateFeatureSelectorModelData> selectedIndices =
            tableEnv.toDataStream(testResult)
                    .transform(
                            "selectIndicesFromPValues",
                            TypeInformation.of(UnivariateFeatureSelectorModelData.class),
                            new SelectIndicesFromPValuesOperator(
                                    getSelectionMode(), getActualSelectionThreshold()))
                    .setParallelism(1);

    UnivariateFeatureSelectorModel fittedModel =
            new UnivariateFeatureSelectorModel()
                    .setModelData(tableEnv.fromDataStream(selectedIndices));
    // Hand this estimator's param map to the model (presumably only params the model
    // itself declares are updated; see ParamUtils.updateExistingParams).
    ParamUtils.updateExistingParams(fittedModel, getParamMap());
    return fittedModel;
}

/**
 * Runs the univariate test that corresponds to the given feature/label type combination and
 * returns the first output table of that test, produced with the flatten option enabled.
 *
 * <ul>
 *   <li>categorical features, categorical label: {@code ChiSqTest}
 *   <li>continuous features, categorical label: {@code ANOVATest}
 *   <li>continuous features, continuous label: {@code FValueTest}
 * </ul>
 *
 * @param data the table to run the test on
 * @param featuresCol name of the features column
 * @param labelCol name of the label column
 * @param featureType configured feature type
 * @param labelType configured label type
 * @return the first output table of the selected test
 * @throws IllegalArgumentException if the combination is none of the three listed above
 */
private Table runUnivariateTest(
        Table data,
        String featuresCol,
        String labelCol,
        String featureType,
        String labelType) {
    final boolean categoricalFeature = CATEGORICAL.equals(featureType);
    final boolean continuousFeature = CONTINUOUS.equals(featureType);
    final boolean categoricalLabel = CATEGORICAL.equals(labelType);
    final boolean continuousLabel = CONTINUOUS.equals(labelType);

    if (categoricalFeature && categoricalLabel) {
        return new ChiSqTest()
                .setFeaturesCol(featuresCol)
                .setLabelCol(labelCol)
                .setFlatten(true)
                .transform(data)[0];
    }
    if (continuousFeature && categoricalLabel) {
        return new ANOVATest()
                .setFeaturesCol(featuresCol)
                .setLabelCol(labelCol)
                .setFlatten(true)
                .transform(data)[0];
    }
    if (continuousFeature && continuousLabel) {
        return new FValueTest()
                .setFeaturesCol(featuresCol)
                .setLabelCol(labelCol)
                .setFlatten(true)
                .transform(data)[0];
    }
    throw new IllegalArgumentException(
            String.format(
                    "Unsupported combination: featureType=%s, labelType=%s.",
                    featureType, labelType));
}