in core/src/main/java/org/apache/mahout/math/solver/LSMR.java [170:474]
/**
 * Computes a least-squares solution of {@code Ax = b} with the LSMR iteration
 * (a port of the Matlab LSMR code, version 1.02, 14 Apr 2010).
 *
 * <p>The method runs a bidiagonalization of {@code A} and applies a sequence of plane
 * rotations to update the solution estimate and running estimates of {@code ||r||},
 * {@code ||A'r||}, {@code ||A||} and {@code cond(A)}. {@code lambda} enters through the
 * rotation {@code Qhat}. Iteration stops as soon as one of the stopping tests fires or
 * the iteration limit is exceeded.
 *
 * <p>Side effects: the instance fields {@code localPointer}, {@code localV},
 * {@code iteration}, {@code residualNorm}, {@code normalEquationResidual}, {@code normA},
 * {@code condA} and {@code xNorm} are overwritten with the state of this solve. The
 * configured {@code iterationLimit} is only read, never modified.
 *
 * @param A the coefficient matrix; {@code A.numRows()} and {@code A.numCols()} give m and n
 * @param b the right-hand side vector
 * @return the solution estimate {@code x}; it is the initial {@code zeros(n)} vector when
 *         {@code ||b|| * ||A'b||} is exactly zero
 */
public Vector solve(Matrix A, Vector b) {
  log.debug(" itn x(1) norm r norm A'r");
  log.debug(" compatible LS norm A cond A");

  // Determine dimensions m and n, and form the first vectors u and v.
  // These satisfy beta*u = b, alpha*v = A'u.
  Matrix transposedA = A.transpose();
  Vector u = b;
  double beta = u.norm(2);
  if (beta > 0) {
    u = u.divide(beta);
  }

  Vector v = transposedA.times(u);
  int m = A.numRows();
  int n = A.numCols();
  int minDim = Math.min(m, n);

  // An iterationLimit of -1 means "use min(m, n)". The effective limit is kept in a
  // local so the configured field keeps its sentinel; writing minDim back into the field
  // would make a later solve on a differently sized matrix reuse this call's limit.
  int maxIterations = iterationLimit == -1 ? minDim : iterationLimit;

  if (log.isDebugEnabled()) {
    log.debug("LSMR - Least-squares solution of Ax = b, based on Matlab Version 1.02, 14 Apr 2010, "
        + "Mahout version {}", getClass().getPackage().getImplementationVersion());
    log.debug(String.format("The matrix A has %d rows and %d cols, lambda = %.4g, atol = %g, btol = %g",
        m, n, lambda, aTolerance, bTolerance));
  }

  double alpha = v.norm(2);
  if (alpha > 0) {
    v.assign(Functions.div(alpha));
  }

  // Initialization for local reorthogonalization.
  localPointer = 0;

  // Preallocate storage for storing the last few v_k. Since with
  // orthogonal v_k's, Krylov subspace method would converge in not
  // more iterations than the number of singular values, more
  // space is not necessary.
  localV = new Vector[Math.min(localSize, minDim)];
  boolean localOrtho = false;
  if (localSize > 0) {
    localOrtho = true;
    localV[0] = v;
  }

  // Initialize variables for 1st iteration.
  iteration = 0;
  double zetabar = alpha * beta;
  double alphabar = alpha;

  Vector h = v;
  Vector hbar = zeros(n);
  Vector x = zeros(n);

  // Initialize variables for estimation of ||r||.
  double betadd = beta;

  // Initialize variables for estimation of ||A|| and cond(A).
  double aNorm = alpha * alpha;

  // Items for use in stopping rules.
  double normb = beta;

  double ctol = 0;
  if (conditionLimit > 0) {
    ctol = 1 / conditionLimit;
  }
  residualNorm = beta;

  // Exit if b=0 or A'b = 0.
  normalEquationResidual = alpha * beta;
  if (normalEquationResidual == 0) {
    return x;
  }

  // Initial line of the iteration log.
  if (log.isDebugEnabled()) {
    double test2 = alpha / beta;
    log.debug("{} {}", iteration, x.get(0));
    log.debug("{} {}", residualNorm, normalEquationResidual);
    double test1 = 1;
    log.debug("{} {}", test1, test2);
  }

  //------------------------------------------------------------------
  //     Main iteration loop.
  //------------------------------------------------------------------
  double rho = 1;
  double rhobar = 1;
  double cbar = 1;
  double sbar = 0;
  double betad = 0;
  double rhodold = 1;
  double tautildeold = 0;
  double thetatilde = 0;
  double zeta = 0;
  double d = 0;
  double maxrbar = 0;
  double minrbar = 1.0e+100;
  StopCode stop = StopCode.CONTINUE;

  // NOTE(review): the counter is incremented after the "<=" check, so up to
  // maxIterations + 1 passes can run before ITERATION_LIMIT is raised. The Matlab
  // reference appears to use "itn < itnlim"; left as-is to keep results unchanged.
  while (iteration <= maxIterations && stop == StopCode.CONTINUE) {
    iteration++;

    // Perform the next step of the bidiagonalization to obtain the
    // next beta, u, alpha, v.  These satisfy the relations
    //      beta*u  =  A*v  - alpha*u,
    //      alpha*v  =  A'*u - beta*v.
    u = A.times(v).minus(u.times(alpha));
    beta = u.norm(2);
    if (beta > 0) {
      u.assign(Functions.div(beta));

      // store data for local-reorthogonalization of V
      if (localOrtho) {
        localVEnqueue(v);
      }
      v = transposedA.times(u).minus(v.times(beta));
      // local-reorthogonalization of V
      if (localOrtho) {
        v = localVOrtho(v);
      }
      alpha = v.norm(2);
      if (alpha > 0) {
        v.assign(Functions.div(alpha));
      }
    }

    // At this point, beta = beta_{k+1}, alpha = alpha_{k+1}.

    // Construct rotation Qhat_{k,2k+1}.
    double alphahat = Math.hypot(alphabar, lambda);
    double chat = alphabar / alphahat;
    double shat = lambda / alphahat;

    // Use a plane rotation (Q_i) to turn B_i to R_i.
    double rhoold = rho;
    rho = Math.hypot(alphahat, beta);
    double c = alphahat / rho;
    double s = beta / rho;
    double thetanew = s * alpha;
    alphabar = c * alpha;

    // Use a plane rotation (Qbar_i) to turn R_i^T to R_i^bar.
    double rhobarold = rhobar;
    double zetaold = zeta;
    double thetabar = sbar * rho;
    // rhotemp is cbar*rho with the cbar of the previous step; it feeds the new rhobar,
    // the new cbar and, further down, the cond(A) estimate.
    double rhotemp = cbar * rho;
    rhobar = Math.hypot(rhotemp, thetanew);
    cbar = rhotemp / rhobar;
    sbar = thetanew / rhobar;
    zeta = cbar * zetabar;
    zetabar = -sbar * zetabar;

    // Update h, h_hat, x.
    hbar = h.minus(hbar.times(thetabar * rho / (rhoold * rhobarold)));
    x.assign(hbar.times(zeta / (rho * rhobar)), Functions.PLUS);
    h = v.minus(h.times(thetanew / rho));

    // Estimate of ||r||.

    // Apply rotation Qhat_{k,2k+1}.
    double betaacute = chat * betadd;
    double betacheck = -shat * betadd;

    // Apply rotation Q_{k,k+1}.
    double betahat = c * betaacute;
    betadd = -s * betaacute;

    // Apply rotation Qtilde_{k-1}.
    // betad = betad_{k-1} here.
    double thetatildeold = thetatilde;
    double rhotildeold = Math.hypot(rhodold, thetabar);
    double ctildeold = rhodold / rhotildeold;
    double stildeold = thetabar / rhotildeold;
    thetatilde = stildeold * rhobar;
    rhodold = ctildeold * rhobar;
    betad = -stildeold * betad + ctildeold * betahat;

    // betad   = betad_k here.
    // rhodold = rhod_k  here.
    tautildeold = (zetaold - thetatildeold * tautildeold) / rhotildeold;
    double taud = (zeta - thetatilde * tautildeold) / rhodold;
    d += betacheck * betacheck;
    residualNorm = Math.sqrt(d + (betad - taud) * (betad - taud) + betadd * betadd);

    // Estimate ||A||. normA is taken between the two updates, i.e. it includes the
    // new beta^2 but not yet the new alpha^2.
    aNorm += beta * beta;
    normA = Math.sqrt(aNorm);
    aNorm += alpha * alpha;

    // Estimate cond(A).
    maxrbar = Math.max(maxrbar, rhobarold);
    if (iteration > 1) {
      minrbar = Math.min(minrbar, rhobarold);
    }
    condA = Math.max(maxrbar, rhotemp) / Math.min(minrbar, rhotemp);

    // Test for convergence.

    // Compute norms for convergence testing.
    normalEquationResidual = Math.abs(zetabar);
    xNorm = x.norm(2);

    // Now use these norms to estimate certain other quantities,
    // some of which will be small near a solution.
    double test1 = residualNorm / normb;
    double test2 = normalEquationResidual / (normA * residualNorm);
    double test3 = 1 / condA;
    double t1 = test1 / (1 + normA * xNorm / normb);
    double rtol = bTolerance + aTolerance * normA * xNorm / normb;

    // The following tests guard against extremely small values of
    // atol, btol or ctol.  (The user may have set any or all of
    // the parameters atol, btol, conlim to 0.)
    // The effect is equivalent to the normal tests using
    // atol = eps,  btol = eps,  conlim = 1/eps.
    // The checks are deliberately not else-if: when several fire, the last one wins.
    if (iteration > maxIterations) {
      stop = StopCode.ITERATION_LIMIT;
    }
    if (1 + test3 <= 1) {
      stop = StopCode.CONDITION_MACHINE_TOLERANCE;
    }
    if (1 + test2 <= 1) {
      stop = StopCode.LEAST_SQUARE_CONVERGED_MACHINE_TOLERANCE;
    }
    if (1 + t1 <= 1) {
      stop = StopCode.CONVERGED_MACHINE_TOLERANCE;
    }

    // Allow for tolerances set by the user.
    if (test3 <= ctol) {
      stop = StopCode.CONDITION;
    }
    if (test2 <= aTolerance) {
      stop = StopCode.CONVERGED;
    }
    if (test1 <= rtol) {
      stop = StopCode.TRIVIAL;
    }

    // See if it is time to print something.
    if (log.isDebugEnabled()) {
      if ((n <= 40) || (iteration <= 10) || (iteration >= maxIterations - 10) || ((iteration % 10) == 0)
          || (test3 <= 1.1 * ctol) || (test2 <= 1.1 * aTolerance) || (test1 <= 1.1 * rtol)
          || (stop != StopCode.CONTINUE)) {
        statusDump(x, normA, condA, test1, test2);
      }
    }
  } // iteration loop

  // Print the stopping condition.
  log.debug("Finished: {}", stop.getMessage());

  return x;
}