public void process()

in flink-ml-lib/src/main/java/org/apache/flink/ml/feature/standardscaler/OnlineStandardScaler.java [123:188]


        /**
         * Folds the rows of one window into the running statistics kept in global state and, if
         * the window contained at least one row, emits a model built from the updated totals.
         *
         * <p>Four values are read from global list state: the running sum vector, the running
         * squared-sum vector, the number of rows seen so far and the current model version. The
         * emitted model carries the version that was read from state; the version written back
         * afterwards is one higher. An empty window emits nothing and writes no state.
         *
         * @param context window context; supplies the global state store and the window itself
         * @param iterable rows of the current window; each must hold a non-null vector in the
         *     input column
         * @param collector receives at most one model per invocation
         * @throws Exception if state access or model construction fails
         */
        public void process(
                ProcessAllWindowFunction<Row, StandardScalerModelData, W>.Context context,
                Iterable<Row> iterable,
                Collector<StandardScalerModelData> collector)
                throws Exception {
            // Look up the four global state handles, in a fixed order.
            ListStateDescriptor<DenseVector> sumDescriptor =
                    new ListStateDescriptor<>("sumState", DenseVectorTypeInfo.INSTANCE);
            ListState<DenseVector> sumState = context.globalState().getListState(sumDescriptor);

            ListStateDescriptor<DenseVector> squaredSumDescriptor =
                    new ListStateDescriptor<>("squaredSumState", DenseVectorTypeInfo.INSTANCE);
            ListState<DenseVector> squaredSumState =
                    context.globalState().getListState(squaredSumDescriptor);

            ListStateDescriptor<Long> countDescriptor =
                    new ListStateDescriptor<>("numElementsState", Types.LONG);
            ListState<Long> numElementsState =
                    context.globalState().getListState(countDescriptor);

            ListStateDescriptor<Long> versionDescriptor =
                    new ListStateDescriptor<>("modelVersionState", Types.LONG);
            ListState<Long> modelVersionState =
                    context.globalState().getListState(versionDescriptor);

            // Restore the running totals; absent state means nothing has been seen yet.
            DenseVector runningSum =
                    OperatorStateUtils.getUniqueElement(sumState, "sumState").orElse(null);
            DenseVector runningSquaredSum =
                    OperatorStateUtils.getUniqueElement(squaredSumState, "squaredSumState")
                            .orElse(null);
            long totalCount =
                    OperatorStateUtils.getUniqueElement(numElementsState, "numElementsState")
                            .orElse(0L);
            long version =
                    OperatorStateUtils.getUniqueElement(modelVersionState, "modelVersionState")
                            .orElse(0L);

            final long countBeforeWindow = totalCount;
            for (Row row : iterable) {
                // Work on a copy so the in-place update below never touches the row's own vector.
                Object rawField = Objects.requireNonNull(row.getField(inputCol));
                Vector vec = ((Vector) rawField).clone();

                // The very first row ever seen fixes the dimension of both accumulators.
                if (totalCount == 0) {
                    runningSum = new DenseVector(vec.size());
                    runningSquaredSum = new DenseVector(vec.size());
                }

                BLAS.axpy(1, vec, runningSum);
                // hDot overwrites vec with its element-wise product with itself (per the Flink
                // ML BLAS contract), so the second axpy accumulates squares.
                BLAS.hDot(vec, vec);
                BLAS.axpy(1, vec, runningSquaredSum);
                totalCount++;
            }

            // Empty window: no new model and no state change.
            if (totalCount - countBeforeWindow <= 0) {
                return;
            }

            long eventTime;
            if (isEventTimeBasedTraining) {
                eventTime = context.window().maxTimestamp();
            } else {
                eventTime = Long.MAX_VALUE;
            }

            // Hand copies to the model so the accumulators stored in state stay private.
            DenseVector sumSnapshot = runningSum.clone();
            DenseVector squaredSumSnapshot = runningSquaredSum.clone();
            StandardScalerModelData model =
                    buildModelData(
                            totalCount, sumSnapshot, squaredSumSnapshot, version, eventTime);
            collector.collect(model);

            // Persist the updated totals and advance the version for the next model.
            sumState.update(Collections.singletonList(runningSum));
            squaredSumState.update(Collections.singletonList(runningSquaredSum));
            numElementsState.update(Collections.singletonList(totalCount));
            modelVersionState.update(Collections.singletonList(version + 1));
        }