public static double regularize()

in flink-ml-lib/src/main/java/org/apache/flink/ml/common/optimizer/RegularizationUtils.java [47:91]


    /**
     * Applies one gradient step of elastic-net regularization to {@code coefficient} in place and
     * returns the regularization loss evaluated at the coefficient values before the update.
     *
     * <p>With {@code w} denoting the coefficient values, the loss is {@code reg * (elasticNet *
     * sum_i |w_i| + (1 - elasticNet) / 2 * sum_i w_i^2)}. Each value is then moved against its
     * (sub)gradient {@code reg * (elasticNet * signum(w_i) + (1 - elasticNet) * w_i)}, scaled by
     * {@code learningRate}.
     *
     * @param coefficient the model coefficient; its values are modified in place.
     * @param reg the overall regularization strength; exactly {@code 0} disables regularization and
     *     leaves {@code coefficient} untouched.
     * @param elasticNet the L1 mixing ratio; {@code 0} means pure L2 and {@code 1} means pure L1.
     *     NOTE(review): presumably expected to lie in {@code [0, 1]} — it is not validated here.
     * @param learningRate the step size applied to the regularization gradient.
     * @return the regularization loss of {@code coefficient} before it was updated.
     */
    public static double regularize(
            DenseVector coefficient,
            final double reg,
            final double elasticNet,
            final double learningRate) {

        if (Double.compare(reg, 0) == 0) {
            return 0;
        } else if (Double.compare(elasticNet, 0) == 0) {
            // Only L2 regularization. norm2 is the (non-squared) norm, so square it to obtain the
            // sum of squared coefficients that the L2 penalty is defined over.
            double norm = BLAS.norm2(coefficient);
            double loss = reg / 2 * norm * norm;
            // The L2 gradient is reg * w, so a gradient step is a uniform rescaling of w.
            BLAS.scal(1 - learningRate * reg, coefficient);
            return loss;
        } else if (Double.compare(elasticNet, 1) == 0) {
            // Only L1 regularization.
            double loss = 0;
            double[] coefficientArray = coefficient.values;
            for (int i = 0; i < coefficientArray.length; i++) {
                double value = coefficientArray[i];
                if (Double.compare(value, 0) == 0) {
                    // A zero coefficient adds no loss and its subgradient is taken as 0.
                    continue;
                }
                // The L1 penalty is |w_i|; signum(w_i) is only its (sub)gradient.
                loss += elasticNet * reg * Math.abs(value);
                coefficientArray[i] -= learningRate * elasticNet * reg * Math.signum(value);
            }
            return loss;
        } else {
            // Both L1 and L2 are not zero.
            double loss = 0;
            double[] coefficientArray = coefficient.values;
            for (int i = 0; i < coefficientArray.length; i++) {
                double value = coefficientArray[i];
                // Loss uses |w_i| for the L1 part and w_i^2 / 2 for the L2 part.
                loss +=
                        elasticNet * reg * Math.abs(value)
                                + (1 - elasticNet) * (reg / 2) * value * value;
                // Step against the combined (sub)gradient, evaluated at the pre-update value.
                coefficientArray[i] -=
                        learningRate
                                * (elasticNet * reg * Math.signum(value)
                                        + (1 - elasticNet) * reg * value);
            }
            return loss;
        }
    }