in flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegression.java [291:321]
/**
 * Applies one per-coordinate FTRL-Proximal step and emits the resulting coefficients.
 *
 * <p>The record value is read as three vectors, presumably of equal length (TODO confirm with
 * the upstream operator):
 *
 * <ul>
 *   <li>element 0 is a gradient;
 *   <li>element 1 holds per-dimension divisors used to average that gradient;
 *   <li>element 2 holds the current coefficients.
 * </ul>
 *
 * <p>Both the gradient array and the coefficient array are overwritten in place. The coefficient
 * array itself (not a copy) is wrapped into the emitted {@code DenseVector}.
 *
 * <p>On the first call, the {@code zParam}/{@code nParam} accumulators are allocated with the
 * gradient's length and added to {@code nParamState}/{@code zParamState}. Every later call
 * mutates those same arrays in place.
 *
 * @param streamRecord record carrying {gradient, divisors, coefficients}
 * @throws Exception propagated from the state {@code add} calls during first-time setup
 */
public void processElement(StreamRecord<DenseVector[]> streamRecord) throws Exception {
    DenseVector[] vectors = streamRecord.getValue();
    double[] gradient = vectors[0].values;
    double[] divisors = vectors[1].values;
    double[] weights = vectors[2].values;

    // Average the gradient in place; a zero divisor leaves that dimension's raw value untouched.
    for (int d = 0; d < gradient.length; d++) {
        double divisor = divisors[d];
        if (divisor != 0.0) {
            gradient[d] = gradient[d] / divisor;
        }
    }

    // Lazily create the accumulators sized from the first gradient seen.
    if (zParam == null) {
        int dim = gradient.length;
        zParam = new double[dim];
        nParam = new double[dim];
        nParamState.add(nParam);
        zParamState.add(zParam);
    }

    for (int d = 0; d < zParam.length; d++) {
        double grad = gradient[d];
        double gradSq = grad * grad;
        double sigma = (Math.sqrt(nParam[d] + gradSq) - Math.sqrt(nParam[d])) / alpha;
        zParam[d] += grad - sigma * weights[d];
        nParam[d] += gradSq;

        double z = zParam[d];
        // While |z| stays within l1, the L1 term pins the coefficient to exactly zero.
        weights[d] =
                Math.abs(z) <= l1
                        ? 0.0
                        : ((z < 0 ? -l1 : l1) - z)
                                / ((beta + Math.sqrt(nParam[d])) / alpha + l2);
    }

    output.collect(new StreamRecord<>(new DenseVector(weights)));
}