private void calculateGradient()

in flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegression.java [351:403]


        /**
         * Computes the local logistic-loss gradient for the oldest buffered mini-batch and emits it
         * downstream.
         *
         * <p>Does nothing unless both a model and at least one buffered batch are present in
         * operator state. Otherwise it:
         *
         * <ul>
         *   <li>removes the model from {@code modelDataState};
         *   <li>removes the oldest batch from {@code localBatchDataState};
         *   <li>if that batch is non-empty, emits a {@code DenseVector[]} of
         *       {@code {gradient, weightSum, modelOrNull}}. Only subtask 0 attaches the model.
         * </ul>
         *
         * <p>Each row holds the feature vector in field 0 and the label in field 1. It may hold a
         * sample weight in field 2; rows with arity 2 get weight 1.0. The weight scales both the
         * per-feature gradient contribution and the per-feature weight sum, so the two stay
         * consistent if a consumer divides one by the other. That division presumably happens
         * downstream (verify). Rows without a weight column produce exactly the same values as
         * an unweighted computation.
         *
         * <p>The {@code p - label} term with a sigmoid {@code p} is the log-loss derivative for
         * labels in {0, 1}; presumably labels are encoded that way (confirm against the caller).
         *
         * @throws Exception if reading or updating operator state fails
         */
        private void calculateGradient() throws Exception {
            if (!modelDataState.get().iterator().hasNext()
                    || !localBatchDataState.get().iterator().hasNext()) {
                return;
            }
            DenseVector modelData =
                    OperatorStateUtils.getUniqueElement(modelDataState, "modelData").get();
            // NOTE(review): the model is consumed here even if the batch below turns out to be
            // empty, in which case nothing is emitted for it. Confirm this is intended.
            modelDataState.clear();

            // Pop the oldest buffered batch and write the remaining batches back to state.
            List<Row[]> pointsList = IteratorUtils.toList(localBatchDataState.get().iterator());
            Row[] points = pointsList.remove(0);
            localBatchDataState.update(pointsList);

            if (points.length == 0) {
                // gradient/weightSum are sized lazily from the first row and may still be null,
                // so there is nothing to emit or reset (they are already zero otherwise).
                return;
            }

            for (Row point : points) {
                Vector vec = point.getFieldAs(0);
                double label = point.getFieldAs(1);
                double weight = point.getArity() == 2 ? 1.0 : point.getFieldAs(2);
                if (gradient == null) {
                    gradient = new double[vec.size()];
                    weightSum = new double[gradient.length];
                }
                double p = BLAS.dot(modelData, vec);
                p = 1 / (1 + Math.exp(-p));
                // Weighted error term shared by every feature of this row. Hoisted out of the
                // per-feature loops; equals (p - label) exactly when weight == 1.0.
                double scaledError = weight * (p - label);
                if (vec instanceof DenseVector) {
                    DenseVector dvec = (DenseVector) vec;
                    for (int i = 0; i < modelData.size(); ++i) {
                        gradient[i] += scaledError * dvec.values[i];
                        weightSum[i] += weight;
                    }
                } else {
                    SparseVector svec = (SparseVector) vec;
                    for (int i = 0; i < svec.indices.length; ++i) {
                        int idx = svec.indices[i];
                        gradient[idx] += scaledError * svec.values[i];
                        weightSum[idx] += weight;
                    }
                }
            }

            output.collect(
                    new StreamRecord<>(
                            new DenseVector[] {
                                // Emit copies: the buffers are reset right below, so the emitted
                                // vectors must not alias them (matters with object reuse).
                                new DenseVector(gradient.clone()),
                                new DenseVector(weightSum.clone()),
                                (getRuntimeContext().getIndexOfThisSubtask() == 0)
                                        ? modelData
                                        : null
                            }));
            // Reset the reusable accumulators for the next batch.
            Arrays.fill(gradient, 0.0);
            Arrays.fill(weightSum, 0.0);
        }