in ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/regression/LinearRegression.java [103:146]
/**
 * Instantiates the gradient optimiser selected by {@code parameters} and stores it in {@code optimiser}.
 *
 * <p>Any hyper-parameter the caller left unset (null) is replaced by the matching {@code DEFAULT_*}
 * constant. The arguments handed to each optimiser are:
 * <ul>
 *   <li>LINEAR_DECAY_SGD / SQRT_DECAY_SGD: learning rate, momentum factor, momentum type</li>
 *   <li>ADA_GRAD: learning rate, epsilon</li>
 *   <li>ADA_DELTA: momentum factor, epsilon (the learning rate is not passed)</li>
 *   <li>ADAM: learning rate, beta1, beta2, epsilon</li>
 *   <li>RMS_PROP: learning rate, momentum factor, epsilon, decay rate</li>
 *   <li>any other optimiser type: simple SGD with a constant learning rate</li>
 * </ul>
 */
private void createOptimiser() {
    // Resolve every setting up front, falling back to the class-level defaults.
    final LinearRegressionParams.OptimizerType selectedOptimizer =
            Optional.ofNullable(parameters.getOptimizerType()).orElse(DEFAULT_OPTIMIZER_TYPE);
    final Double rate =
            Optional.ofNullable(parameters.getLearningRate()).orElse(DEFAULT_LEARNING_RATE);
    final Double momentumCoefficient =
            Optional.ofNullable(parameters.getMomentumFactor()).orElse(DEFAULT_MOMENTUM_FACTOR);
    final Double eps =
            Optional.ofNullable(parameters.getEpsilon()).orElse(DEFAULT_EPSILON);
    final Double firstBeta =
            Optional.ofNullable(parameters.getBeta1()).orElse(DEFAULT_BETA1);
    final Double secondBeta =
            Optional.ofNullable(parameters.getBeta2()).orElse(DEFAULT_BETA2);
    final LinearRegressionParams.MomentumType requestedMomentum =
            Optional.ofNullable(parameters.getMomentumType()).orElse(DEFAULT_MOMENTUM_TYPE);
    final Double rateDecay =
            Optional.ofNullable(parameters.getDecayRate()).orElse(DEFAULT_DECAY_RATE);

    // Only NESTEROV is mapped explicitly; every other momentum type becomes STANDARD.
    // equals() is invoked on the resolved value so a null still fails fast, as an enum switch would.
    final SGD.Momentum sgdMomentum = requestedMomentum.equals(LinearRegressionParams.MomentumType.NESTEROV)
            ? SGD.Momentum.NESTEROV
            : SGD.Momentum.STANDARD;

    switch (selectedOptimizer) {
        case ADAM:
            optimiser = new Adam(rate, firstBeta, secondBeta, eps);
            break;
        case ADA_GRAD:
            optimiser = new AdaGrad(rate, eps);
            break;
        case ADA_DELTA:
            // NOTE(review): the momentum factor is reused as AdaDelta's first argument
            // (presumably its decay/rho term) — confirm this is intended.
            optimiser = new AdaDelta(momentumCoefficient, eps);
            break;
        case RMS_PROP:
            optimiser = new RMSProp(rate, momentumCoefficient, eps, rateDecay);
            break;
        case SQRT_DECAY_SGD:
            optimiser = SGD.getSqrtDecaySGD(rate, momentumCoefficient, sgdMomentum);
            break;
        case LINEAR_DECAY_SGD:
            optimiser = SGD.getLinearDecaySGD(rate, momentumCoefficient, sgdMomentum);
            break;
        default:
            // Fallback: plain SGD whose learning rate stays constant during training.
            optimiser = SGD.getSimpleSGD(rate, momentumCoefficient, sgdMomentum);
            break;
    }
}