in commons-math-legacy/src/main/java/org/apache/commons/math4/legacy/distribution/fitting/MultivariateNormalMixtureExpectationMaximization.java [132:259]
/**
 * Fits a mixture of multivariate normal distributions to the data held by
 * this instance using Expectation-Maximization, starting from
 * {@code initialMixture}.
 *
 * <p>Each pass evaluates the mean log-likelihood per data row of the
 * current model (E-step) and then replaces the model with re-estimated
 * weights, means and covariance matrices (M-step). Iteration stops once the
 * absolute difference between two successive mean log-likelihoods is no
 * longer greater than {@code threshold}. The {@code fittedModel} and
 * {@code logLikelihood} fields are updated on every pass; the stored
 * log-likelihood is the one measured at the start of the last pass, i.e.
 * before that pass's M-step.</p>
 *
 * @param initialMixture Model providing the starting weights and
 * component distributions.
 * @param maxIterations Bound on the number of passes; must be at least 1.
 * @param threshold Convergence tolerance on the change in mean
 * log-likelihood; must be at least {@code Double.MIN_VALUE} and not NaN.
 * @throws SingularMatrixException declared by this method; presumably
 * raised when a re-estimated covariance matrix is singular (thrown by
 * code outside this method - confirm against the mixture constructor).
 * @throws NotStrictlyPositiveException if {@code maxIterations < 1}, or if
 * {@code threshold} is NaN or smaller than {@code Double.MIN_VALUE}.
 * @throws DimensionMismatchException if the mean vector of the first
 * component of {@code initialMixture} does not have as many entries as a
 * data row has columns.
 * @throws ConvergenceException if the change in mean log-likelihood is
 * still greater than {@code threshold} when the iteration bound is hit.
 */
public void fit(final MixtureMultivariateNormalDistribution initialMixture,
                final int maxIterations,
                final double threshold)
    throws SingularMatrixException,
           NotStrictlyPositiveException,
           DimensionMismatchException {
    if (maxIterations < 1) {
        throw new NotStrictlyPositiveException(maxIterations);
    }
    // Written as a negated ">=" so that NaN is rejected too. With the plain
    // "threshold < MIN_VALUE" form a NaN threshold slipped through, made
    // every later "> threshold" comparison false, and the method returned
    // without running a single iteration or reporting non-convergence.
    if (!(threshold >= Double.MIN_VALUE)) {
        throw new NotStrictlyPositiveException(threshold);
    }

    final int n = data.length;
    // Number of data columns. Jagged data already rejected in constructor,
    // so we can assume the lengths of each row are equal.
    final int numCols = data[0].length;

    // Fetch the component list once rather than once per use.
    final List<Pair<Double, MultivariateNormalDistribution>> initialComponents
        = initialMixture.getComponents();
    final int k = initialComponents.size();
    final int numMeanColumns
        = initialComponents.get(0).getSecond().getMeans().length;

    if (numMeanColumns != numCols) {
        throw new DimensionMismatchException(numMeanColumns, numCols);
    }

    int numIterations = 0;
    // Starting with 0 vs. -infinity guarantees an infinite initial
    // difference, so the loop body is always entered at least once.
    double previousLogLikelihood = 0d;

    logLikelihood = Double.NEGATIVE_INFINITY;

    // Initialize model to fit to initial mixture.
    fittedModel = new MixtureMultivariateNormalDistribution(initialComponents);

    // NOTE(review): the post-increment combined with "<=" lets the body run
    // up to maxIterations + 1 times. Left unchanged because tightening it
    // would alter which calls converge and which throw.
    while (numIterations++ <= maxIterations &&
           JdkMath.abs(previousLogLikelihood - logLikelihood) > threshold) {
        previousLogLikelihood = logLikelihood;
        double sumLogLikelihood = 0d;

        // Mixture components
        final List<Pair<Double, MultivariateNormalDistribution>> components
            = fittedModel.getComponents();

        // Weight and distribution of each component
        final double[] weights = new double[k];

        final MultivariateNormalDistribution[] mvns = new MultivariateNormalDistribution[k];

        for (int j = 0; j < k; j++) {
            weights[j] = components.get(j).getFirst();
            mvns[j] = components.get(j).getSecond();
        }

        // E-step: compute the data dependent parameters of the expectation
        // function.
        // gamma[i][j]: share of row i's total mixture density that is
        // contributed by component j.
        final double[][] gamma = new double[n][k];

        // Sum of gamma for each component
        final double[] gammaSums = new double[k];

        // Sum of gamma times its row for each component
        final double[][] gammaDataProdSums = new double[k][numCols];

        for (int i = 0; i < n; i++) {
            final double rowDensity = fittedModel.density(data[i]);
            sumLogLikelihood += JdkMath.log(rowDensity);

            for (int j = 0; j < k; j++) {
                gamma[i][j] = weights[j] * mvns[j].density(data[i]) / rowDensity;
                gammaSums[j] += gamma[i][j];

                for (int col = 0; col < numCols; col++) {
                    gammaDataProdSums[j][col] += gamma[i][j] * data[i][col];
                }
            }
        }

        // Mean log-likelihood per row of the model used in this E-step.
        logLikelihood = sumLogLikelihood / n;

        // M-step: compute the new parameters based on the expectation
        // function.
        final double[] newWeights = new double[k];
        final double[][] newMeans = new double[k][numCols];

        for (int j = 0; j < k; j++) {
            newWeights[j] = gammaSums[j] / n;
            for (int col = 0; col < numCols; col++) {
                newMeans[j][col] = gammaDataProdSums[j][col] / gammaSums[j];
            }
        }

        // Compute new covariance matrices
        final RealMatrix[] newCovMats = new RealMatrix[k];
        for (int j = 0; j < k; j++) {
            newCovMats[j] = new Array2DRowRealMatrix(numCols, numCols);
        }
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < k; j++) {
                // Deviation of row i from the new mean of component j,
                // wrapped in a matrix so it can be multiplied by its
                // transpose.
                final RealMatrix vec
                    = new Array2DRowRealMatrix(MathArrays.ebeSubtract(data[i], newMeans[j]));
                // Outer product of the deviation, weighted by gamma[i][j].
                final RealMatrix dataCov
                    = vec.multiply(vec.transpose()).scalarMultiply(gamma[i][j]);
                newCovMats[j] = newCovMats[j].add(dataCov);
            }
        }

        // Converting to arrays for use by fitted model
        final double[][][] newCovMatArrays = new double[k][numCols][numCols];
        for (int j = 0; j < k; j++) {
            // Normalize the weighted sum by the component's total gamma.
            newCovMats[j] = newCovMats[j].scalarMultiply(1d / gammaSums[j]);
            newCovMatArrays[j] = newCovMats[j].getData();
        }

        // Update current model
        fittedModel = new MixtureMultivariateNormalDistribution(newWeights,
                                                                newMeans,
                                                                newCovMatArrays);
    }

    if (JdkMath.abs(previousLogLikelihood - logLikelihood) > threshold) {
        // Did not converge before the maximum number of iterations
        throw new ConvergenceException();
    }
}