in flink-ml-lib/src/main/java/org/apache/flink/ml/feature/countvectorizer/CountVectorizerModel.java [149:187]
/**
 * Converts the document stored in {@code inputCol} into a term-count vector over the model's
 * vocabulary and appends that vector to the input row.
 *
 * <p>On the first call, the vocabulary is read from the broadcast model data and indexed as
 * term -> position in {@code modelData.vocabulary}. Words not in the vocabulary are ignored. A
 * vocabulary term is emitted only if it occurs in the document and its count is at least the
 * effective minimum term frequency. That threshold is {@code minTF} itself when
 * {@code minTF >= 1.0}; otherwise it is {@code minTF} multiplied by the document's total token
 * count. An emitted term's value is its count, or {@code 1.0} when {@code binary} is set.
 *
 * @param row the input row; its {@code inputCol} field is read as a {@code String[]}
 * @return the input row with a sparse vector of size {@code vocabulary.size()} appended
 */
public Row map(Row row) throws Exception {
    if (vocabulary == null) {
        CountVectorizerModelData modelData =
                (CountVectorizerModelData)
                        getRuntimeContext().getBroadcastVariable(broadcastKey).get(0);
        vocabulary = new HashMap<>();
        IntStream.range(0, modelData.vocabulary.length)
                .forEach(i -> vocabulary.put(modelData.vocabulary[i], i));
    }
    String[] document = (String[]) row.getField(inputCol);

    // Count only terms that actually occur, keyed by vocabulary index. This keeps per-row work
    // proportional to the document length rather than the vocabulary size. It also guarantees
    // that a term with count 0 is never emitted, even when the effective threshold is 0
    // (minTF == 0, or a fractional minTF applied to an empty document).
    HashMap<Integer, Double> termCounts = new HashMap<>();
    for (String word : document) {
        Integer index = vocabulary.get(word);
        if (index != null) {
            termCounts.merge(index, 1.0, Double::sum);
        }
    }

    // Fractional minTF is relative to the document's total token count, including
    // out-of-vocabulary words.
    double actualMinTF = minTF >= 1.0 ? minTF : document.length * minTF;

    // Emit the kept indices in ascending order.
    int[] indices =
            termCounts.entrySet().stream()
                    .filter(entry -> entry.getValue() >= actualMinTF)
                    .mapToInt(entry -> entry.getKey())
                    .sorted()
                    .toArray();
    double[] values = new double[indices.length];
    for (int i = 0; i < indices.length; i++) {
        values[i] = binary ? 1.0 : termCounts.get(indices[i]);
    }

    SparseVector outputVec = Vectors.sparse(vocabulary.size(), indices, values);
    return RowUtils.append(row, outputVec);
}