in flink-ml-lib/src/main/java/org/apache/flink/ml/common/optimizer/RegularizationUtils.java [47:91]
public static double regularize(
        DenseVector coefficient,
        final double reg,
        final double elasticNet,
        final double learningRate) {
    // Applies one elastic-net regularization step to `coefficient` IN PLACE and returns the
    // regularization penalty evaluated on the coefficients as they were BEFORE the update:
    //
    //   penalty(w) = reg * (elasticNet * ||w||_1 + (1 - elasticNet) / 2 * ||w||_2^2)
    //
    // The update is a (sub)gradient step of size `learningRate` on that penalty. The L1-only
    // (elasticNet == 1) and L2-only (elasticNet == 0) cases are special-cased, but every branch
    // uses the same penalty definition so the returned value is continuous in `elasticNet`.
    // Returns 0 and leaves `coefficient` untouched when `reg` is 0.
    if (Double.compare(reg, 0) == 0) {
        return 0;
    }
    final double[] values = coefficient.values;
    if (Double.compare(elasticNet, 0) == 0) {
        // Only L2 regularization. The penalty uses the SQUARED norm, which is what the
        // gradient step w <- w * (1 - learningRate * reg) below minimizes. It is computed
        // before scaling so it reflects the pre-update coefficients.
        double squaredNorm = 0;
        for (double v : values) {
            squaredNorm += v * v;
        }
        BLAS.scal(1 - learningRate * reg, coefficient);
        return reg / 2 * squaredNorm;
    } else if (Double.compare(elasticNet, 1) == 0) {
        // Only L1 regularization. The penalty is reg * sum(|w_i|); each non-zero weight is
        // moved toward zero by learningRate * reg (the subgradient of |w_i| is sign(w_i)).
        double loss = 0;
        for (int i = 0; i < values.length; i++) {
            if (Double.compare(values[i], 0) == 0) {
                continue;
            }
            // Penalty uses |w_i|; sign(w_i) would make the "loss" negative for negative weights.
            loss += elasticNet * reg * Math.abs(values[i]);
            values[i] -= learningRate * elasticNet * reg * Math.signum(values[i]);
        }
        return loss;
    } else {
        // Both L1 and L2 terms are non-zero.
        double loss = 0;
        for (int i = 0; i < values.length; i++) {
            final double w = values[i];
            loss += elasticNet * reg * Math.abs(w) + (1 - elasticNet) * (reg / 2) * w * w;
            values[i] -=
                    (learningRate
                            * (elasticNet * reg * Math.signum(w)
                                    + (1 - elasticNet) * reg * w));
        }
        return loss;
    }
}