double maximumLikelihoodShape()

in lib/maths/common/CGammaRateConjugate.cc [144:294]


double maximumLikelihoodShape(double oldShape,
                              const TMeanAccumulator& oldLogMean,
                              const TMeanAccumulator& newLogMean,
                              const TMeanVarAccumulator& oldMoments,
                              const TMeanVarAccumulator& newMoments) {
    if (CBasicStatistics::count(newMoments) < NON_INFORMATIVE_COUNT) {
        return oldShape;
    }

    static const double EPS = 1e-3;

    // Use large maximum growth factors for root bracketing because
    // overshooting is less costly than extra iterations to bracket
    // the root since we get cubic convergence in the solving loop.
    // The derivative of the digamma function is monotone decreasing
    // so we use a higher maximum growth factor on the upside.
    static const double MIN_DOWN_FACTOR = 0.25;
    static const double MAX_UP_FACTOR = 8.0;

    std::size_t maxIterations = 20;

    double oldNumber = CBasicStatistics::count(oldMoments);
    double oldMean = CBasicStatistics::mean(oldMoments);

    double oldTarget = 0.0;
    if (oldNumber * oldMean > 0.0) {
        oldTarget = std::log(oldNumber * oldMean) - CBasicStatistics::mean(oldLogMean);
    }

    double newNumber = CBasicStatistics::count(newMoments);
    double newMean = CBasicStatistics::mean(newMoments);

    if (newNumber * newMean == 0.0) {
        return 0.0;
    }
    double target = std::log(newNumber * newMean) - CBasicStatistics::mean(newLogMean);

    // Fall back to method of moments if maximum-likelihood fails.
    double bestGuess = 1.0;
    if (CBasicStatistics::variance(newMoments) > 0.0) {
        bestGuess = newMean * newMean / CBasicStatistics::variance(newMoments);
    }

    // If we've estimated the shape before the old shape will typically
    // be a very good initial estimate. Otherwise, use the best guess.
    double x0 = bestGuess;
    if (oldNumber > NON_INFORMATIVE_COUNT) {
        x0 = oldShape;
    }

    TDoubleDoublePr bracket(x0, x0);

    double downFactor = 0.8;
    double upFactor = 1.4;

    if (oldNumber > NON_INFORMATIVE_COUNT) {
        // Compute, very approximately, minus the gradient of the function
        // at the old shape. We just use the chord from the origin to the
        // target value and truncate its value so the bracketing loop is
        // well behaved.
        double gradient = 1.0;
        if (oldShape > 0.0) {
            gradient = CTools::truncate(oldTarget / oldShape, EPS, 1.0);
        }

        // Choose the growth factors so we will typically bracket the root
        // in one iteration and not overshoot too much. Again we truncate
        // the values so that bracketing loop is well behaved.
        double dTarget = std::fabs(target - oldTarget);
        downFactor = CTools::truncate(1.0 - 2.0 * dTarget / gradient,
                                      MIN_DOWN_FACTOR, 1.0 - EPS);
        upFactor = CTools::truncate(1.0 + 2.0 * dTarget / gradient, 1.0 + EPS, MAX_UP_FACTOR);
    }

    CLikelihoodDerivativeFunction derivative(newNumber, target);
    double f0 = 0.0;
    TDoubleDoublePr fBracket(f0, f0);

    try {
        fBracket.first = fBracket.second = f0 = derivative(x0);

        if (f0 == 0.0) {
            // We're done.
            return x0;
        }

        // The target function is monotone decreasing. The rate at which we
        // change the down and up factors in this loop has been determined
        // empirically to give a good expected total number of evaluations
        // of the likelihood derivative function across a range of different
        // process gamma shapes and rates. In particular, the mean total
        // number of evaluations used by this function is around five.
        for (/**/; maxIterations > 0; --maxIterations) {
            if (fBracket.first < 0.0) {
                bracket.second = bracket.first;
                fBracket.second = fBracket.first;

                bracket.first *= downFactor;
                fBracket.first = derivative(bracket.first);

                downFactor = std::max(0.8 * downFactor, MIN_DOWN_FACTOR);
            } else if (fBracket.second > 0.0) {
                bracket.first = bracket.second;
                fBracket.first = fBracket.second;

                bracket.second *= upFactor;
                fBracket.second = derivative(bracket.second);

                upFactor = std::min(1.4 * upFactor, MAX_UP_FACTOR);
            } else {
                break;
            }
        }
    } catch (const std::exception& e) {
        LOG_ERROR(<< "Failed to bracket root: " << e.what() << ", newNumber = " << newNumber
                  << ", newMean = " << newMean << ", newLogMean = " << newLogMean
                  << ", x0 = " << x0 << ", f(x0) = " << f0 << ", bracket = " << bracket
                  << ", f(bracket) = " << fBracket << ", bestGuess = " << bestGuess);
        return bestGuess;
    }

    if (maxIterations == 0) {
        LOG_TRACE(<< "Failed to bracket root:"
                  << " newNumber = " << newNumber << ", newMean = " << newMean
                  << ", newLogMean = " << newLogMean << ", x0 = " << x0
                  << ", f(x0) = " << f0 << ", bracket = " << bracket
                  << ", f(bracket) = " << fBracket << ", bestGuess = " << bestGuess);
        return bestGuess;
    }

    LOG_TRACE(<< "newNumber = " << newNumber << ", newMean = " << newMean
              << ", newLogMean = " << newLogMean << ", oldTarget = " << oldTarget
              << ", target = " << target << ", upFactor = " << upFactor
              << ", downFactor = " << downFactor << ", x0 = " << x0 << ", f(x0) = " << f0
              << ", bracket = " << bracket << ", f(bracket) = " << fBracket);

    try {
        CEqualWithTolerance<double> tolerance(CToleranceTypes::E_AbsoluteTolerance, EPS * x0);
        CSolvers::solve(bracket.first, bracket.second, fBracket.first, fBracket.second,
                        derivative, maxIterations, tolerance, bestGuess);
    } catch (const std::exception& e) {
        LOG_ERROR(<< "Failed to solve: " << e.what() << ", newNumber = " << newNumber
                  << ", x0 = " << x0 << ", f(x0) = " << f0 << ", bracket = " << bracket
                  << ", f(bracket) = " << fBracket << ", bestGuess = " << bestGuess);
        return bestGuess;
    }

    LOG_TRACE(<< "bracket = " << bracket);

    return (bracket.first + bracket.second) / 2.0;
}