in flink-ml-lib/src/main/java/org/apache/flink/ml/feature/standardscaler/OnlineStandardScaler.java [123:188]
public void process(
        ProcessAllWindowFunction<Row, StandardScalerModelData, W>.Context context,
        Iterable<Row> iterable,
        Collector<StandardScalerModelData> collector)
        throws Exception {
    // Folds the rows of the current window into running statistics that are kept in
    // the function's global (cross-window) state:
    //   - the element-wise sum of the input vectors,
    //   - the element-wise sum of their squares,
    //   - the total number of elements seen so far,
    //   - a model version counter.
    // If this window contributed at least one row, a single model-data record is
    // emitted from the updated statistics. It carries the version value read from
    // state, and the incremented version is then persisted. A window without rows
    // emits nothing and leaves the state unchanged.

    // Handles to the accumulators persisted by previous windows.
    ListState<DenseVector> sumState =
            context.globalState()
                    .getListState(
                            new ListStateDescriptor<DenseVector>(
                                    "sumState", DenseVectorTypeInfo.INSTANCE));
    ListState<DenseVector> squaredSumState =
            context.globalState()
                    .getListState(
                            new ListStateDescriptor<DenseVector>(
                                    "squaredSumState", DenseVectorTypeInfo.INSTANCE));
    ListState<Long> countState =
            context.globalState()
                    .getListState(new ListStateDescriptor<Long>("numElementsState", Types.LONG));
    ListState<Long> versionState =
            context.globalState()
                    .getListState(new ListStateDescriptor<Long>("modelVersionState", Types.LONG));

    // Each state holds at most one value; missing values mean "nothing accumulated yet".
    DenseVector runningSum =
            OperatorStateUtils.getUniqueElement(sumState, "sumState").orElse(null);
    DenseVector runningSquaredSum =
            OperatorStateUtils.getUniqueElement(squaredSumState, "squaredSumState")
                    .orElse(null);
    long count =
            OperatorStateUtils.getUniqueElement(countState, "numElementsState").orElse(0L);
    long version =
            OperatorStateUtils.getUniqueElement(versionState, "modelVersionState").orElse(0L);

    final long countBefore = count;
    for (Row row : iterable) {
        // Work on a copy: the vector is squared in place below and the original row
        // must not be mutated.
        Vector vec = ((Vector) Objects.requireNonNull(row.getField(inputCol))).clone();
        if (count == 0) {
            // Very first element overall: size the accumulators after it.
            runningSum = new DenseVector(vec.size());
            runningSquaredSum = new DenseVector(vec.size());
        }
        BLAS.axpy(1, vec, runningSum);
        // Flink ML's BLAS.hDot(x, y) stores the element-wise product x (.) y into y,
        // so after this call vec holds the squared entries of the input.
        BLAS.hDot(vec, vec);
        BLAS.axpy(1, vec, runningSquaredSum);
        count++;
    }

    if (count == countBefore) {
        // Empty window: no new statistics, so nothing is emitted or persisted.
        return;
    }

    // Timestamp for the emitted model: the window's max timestamp under event-time
    // training, otherwise Long.MAX_VALUE.
    long modelTimestamp = Long.MAX_VALUE;
    if (isEventTimeBasedTraining) {
        modelTimestamp = context.window().maxTimestamp();
    }
    // Emit copies so that downstream consumers never share the vectors kept in state.
    StandardScalerModelData modelData =
            buildModelData(
                    count,
                    runningSum.clone(),
                    runningSquaredSum.clone(),
                    version,
                    modelTimestamp);
    collector.collect(modelData);

    // Persist the updated accumulators and advance the model version for the next window.
    sumState.update(Collections.singletonList(runningSum));
    squaredSumState.update(Collections.singletonList(runningSquaredSum));
    countState.update(Collections.singletonList(count));
    versionState.update(Collections.singletonList(version + 1));
}