private void createOptimiser()

in ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/regression/LinearRegression.java [103:146]


    private void createOptimiser() {
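        // Resolve each optimizer hyperparameter from the request, falling back to the class defaults when unset.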
        LinearRegressionParams.OptimizerType optimizerType = Optional.ofNullable(parameters.getOptimizerType()).orElse(DEFAULT_OPTIMIZER_TYPE);
        Double learningRate = Optional.ofNullable(parameters.getLearningRate()).orElse(DEFAULT_LEARNING_RATE);
        Double momentumFactor = Optional.ofNullable(parameters.getMomentumFactor()).orElse(DEFAULT_MOMENTUM_FACTOR);
        Double epsilon = Optional.ofNullable(parameters.getEpsilon()).orElse(DEFAULT_EPSILON);
        Double beta1 = Optional.ofNullable(parameters.getBeta1()).orElse(DEFAULT_BETA1);
        Double beta2 = Optional.ofNullable(parameters.getBeta2()).orElse(DEFAULT_BETA2);
        LinearRegressionParams.MomentumType momentumType = Optional.ofNullable(parameters.getMomentumType()).orElse(DEFAULT_MOMENTUM_TYPE);
        Double decayRate = Optional.ofNullable(parameters.getDecayRate()).orElse(DEFAULT_DECAY_RATE);

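        // Map the request-level momentum type onto Tribuo's SGD momentum variants.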
        SGD.Momentum momentum;
        switch (momentumType) {
            case NESTEROV:
                momentum = SGD.Momentum.NESTEROV;
                break;
            default:
                momentum = SGD.Momentum.STANDARD;
                break;
        }
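        // Build the Tribuo optimiser that matches the requested optimizer type.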
        switch (optimizerType) {
            case LINEAR_DECAY_SGD:
                optimiser = SGD.getLinearDecaySGD(learningRate, momentumFactor, momentum);
                break;
            case SQRT_DECAY_SGD:
                optimiser = SGD.getSqrtDecaySGD(learningRate, momentumFactor, momentum);
                break;
            case ADA_GRAD:
                optimiser = new AdaGrad(learningRate, epsilon);
                break;
            case ADA_DELTA:
                optimiser = new AdaDelta(momentumFactor, epsilon);
                break;
            case ADAM:
                optimiser = new Adam(learningRate, beta1, beta2, epsilon);
                break;
            case RMS_PROP:
                optimiser = new RMSProp(learningRate, momentumFactor, epsilon, decayRate);
                break;
            default:
                // Use default SGD with a constant learning rate.
                optimiser = SGD.getSimpleSGD(learningRate, momentumFactor, momentum);
                break;
        }
    }
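
A minimal standalone sketch of the same parameter-to-optimiser mapping, assuming Tribuo's optimisers (org.tribuo.math.optimisers.*) are on the classpath; the class name and hyperparameter values below are illustrative placeholders, not the DEFAULT_* constants defined in LinearRegression.

    import org.tribuo.math.StochasticGradientOptimiser;
    import org.tribuo.math.optimisers.Adam;
    import org.tribuo.math.optimisers.SGD;

    public class OptimiserSketch {
        public static void main(String[] args) {
            // Mirrors the LINEAR_DECAY_SGD branch: SGD with a linearly decaying
            // learning rate, a momentum factor, and Nesterov momentum.
            StochasticGradientOptimiser sgd =
                SGD.getLinearDecaySGD(0.01, 0.9, SGD.Momentum.NESTEROV);

            // Mirrors the ADAM branch: Adam(initialLearningRate, betaOne, betaTwo, epsilon).
            StochasticGradientOptimiser adam = new Adam(0.001, 0.9, 0.99, 1e-6);

            System.out.println(sgd + " / " + adam);
        }
    }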