in modules/ml-ext/ml/src/main/java/org/apache/ignite/ml/clustering/gmm/GmmTrainer.java [419:472]
/**
 * Builds the initial GMM model: picks initial means (unless they are already set), assigns
 * partition data to clusters by those means, estimates one covariance matrix per component and
 * wraps everything into a model with equal component weights.
 *
 * <p>If an attempt fails with {@link SingularMatrixException} or {@link IllegalArgumentException},
 * the current {@code initialMeans} are discarded (including means that were set before this call)
 * and the next attempt selects new random means. At most {@code maxCountOfInitTries} failed
 * attempts are tolerated.
 *
 * @param dataset Dataset with GMM partition data.
 * @return Initial model, or an empty optional when no covariance matrices could be computed.
 * @throws RuntimeException If initialization failed {@code maxCountOfInitTries} times; the last
 *      failure is attached as the cause.
 */
private Optional<GmmModel> init(Dataset<EmptyContext, GmmPartitionData> dataset) {
    int cntOfTries = 0;

    while (true) {
        try {
            if (initialMeans == null) {
                // Candidates are sorted before shuffling - presumably so that the shuffle result depends
                // only on the random generator and not on the order partitions reported in (TODO confirm).
                List<Vector> randomMeansSets = Stream.of(dataset.compute(
                    selectNRandomXsMapper(countOfComponents),
                    GmmTrainer::selectNRandomXsReducer))
                    .flatMap(Stream::of)
                    .sorted(Comparator.comparingDouble(Vector::getLengthSquared))
                    .collect(Collectors.toList());

                Collections.shuffle(randomMeansSets, environment.randomNumbersGenerator());

                A.ensure(
                    randomMeansSets.size() >= countOfComponents,
                    "There is not enough data in dataset for select N random means"
                );

                initialMeans = randomMeansSets.subList(0, countOfComponents)
                    .toArray(new Vector[countOfComponents]);
            }

            dataset.compute(data -> GmmPartitionData.estimateLikelihoodClusters(data, initialMeans));

            // Every component starts with the same weight.
            double uniformWeight = 1. / countOfComponents;

            List<Matrix> initCovs = CovarianceMatricesAggregator.computeCovariances(
                dataset,
                VectorUtils.fill(uniformWeight, countOfComponents),
                initialMeans
            );

            if (initCovs.isEmpty())
                return Optional.empty();

            List<MultivariateGaussianDistribution> distributions = new ArrayList<>();

            for (int i = 0; i < countOfComponents; i++)
                distributions.add(new MultivariateGaussianDistribution(initialMeans[i], initCovs.get(i)));

            return Optional.of(new GmmModel(
                VectorUtils.of(DoubleStream.generate(() -> uniformWeight).limit(countOfComponents).toArray()),
                distributions
            ));
        }
        catch (SingularMatrixException | IllegalArgumentException e) {
            // Count the failed attempt before composing the message, so that the reported number equals
            // the number of attempts actually made (the first failure is try 1, not try 0).
            cntOfTries += 1;

            String msg = "Cannot construct non-singular covariance matrix by data. " +
                "Try to select other initial means or other model trainer [number of tries = " + cntOfTries + "]";

            environment.logger().log(MLLogger.VerboseLevel.HIGH, msg);

            // Drop the means that led to the failure; the next iteration selects random ones.
            initialMeans = null;

            if (cntOfTries >= maxCountOfInitTries)
                throw new RuntimeException(msg, e);
        }
    }
}