in flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegression.java [351:403]
private void calculateGradient() throws Exception {
if (!modelDataState.get().iterator().hasNext()
|| !localBatchDataState.get().iterator().hasNext()) {
return;
}
DenseVector modelData =
OperatorStateUtils.getUniqueElement(modelDataState, "modelData").get();
modelDataState.clear();
List<Row[]> pointsList = IteratorUtils.toList(localBatchDataState.get().iterator());
Row[] points = pointsList.remove(0);
localBatchDataState.update(pointsList);
for (Row point : points) {
Vector vec = point.getFieldAs(0);
double label = point.getFieldAs(1);
double weight = point.getArity() == 2 ? 1.0 : point.getFieldAs(2);
if (gradient == null) {
gradient = new double[vec.size()];
weightSum = new double[gradient.length];
}
double p = BLAS.dot(modelData, vec);
p = 1 / (1 + Math.exp(-p));
if (vec instanceof DenseVector) {
DenseVector dvec = (DenseVector) vec;
for (int i = 0; i < modelData.size(); ++i) {
gradient[i] += (p - label) * dvec.values[i];
weightSum[i] += 1.0;
}
} else {
SparseVector svec = (SparseVector) vec;
for (int i = 0; i < svec.indices.length; ++i) {
int idx = svec.indices[i];
gradient[idx] += (p - label) * svec.values[i];
weightSum[idx] += weight;
}
}
}
if (points.length > 0) {
output.collect(
new StreamRecord<>(
new DenseVector[] {
new DenseVector(gradient),
new DenseVector(weightSum),
(getRuntimeContext().getIndexOfThisSubtask() == 0)
? modelData
: null
}));
}
Arrays.fill(gradient, 0.0);
Arrays.fill(weightSum, 0.0);
}