in opennlp-tools/src/main/java/opennlp/tools/ml/maxent/GISTrainer.java [566:654]
/**
 * Runs a single GIS training iteration: computes the model expectation of every
 * feature in parallel, merges the per-thread results, and updates the model
 * parameters toward the observed expectations.
 *
 * @param correctionConstant the GIS correction constant used to scale each
 *        parameter update (the non-Gaussian update divides by it)
 * @param completionService the service the expectation-compute tasks are
 *        submitted to; one task per worker thread
 * @param iteration the 1-based index of this iteration, reported to the
 *        training progress monitor
 * @return the log-likelihood of the training data under the updated model
 * @throws IllegalStateException if the calling thread is interrupted while
 *         waiting for a worker task (interruption is not yet supported)
 * @throws RuntimeException if a worker task failed; the original failure is
 *         preserved as the cause
 */
private double nextIteration(double correctionConstant,
    CompletionService<ModelExpectationComputeTask> completionService,
    int iteration) {
  // compute contribution of p(a|b_i) for each feature and the new
  // correction parameter
  double loglikelihood = 0.0;
  int numEvents = 0;
  int numCorrect = 0;

  // Each thread gets an equal number of tasks; if the number of tasks
  // is not divisible by the number of threads, the first "leftOver"
  // threads have one extra task.
  int numberOfThreads = modelExpects.length;
  int taskSize = numUniqueEvents / numberOfThreads;
  int leftOver = numUniqueEvents % numberOfThreads;

  // submit all tasks to the completion service.
  for (int i = 0; i < numberOfThreads; i++) {
    if (i < leftOver) {
      // The first "leftOver" tasks each cover one extra event; the
      // "+ i" start offset accounts for the extra events already assigned.
      completionService.submit(new ModelExpectationComputeTask(i, i * taskSize + i,
          taskSize + 1));
    } else {
      completionService.submit(new ModelExpectationComputeTask(i,
          i * taskSize + leftOver, taskSize));
    }
  }

  for (int i = 0; i < numberOfThreads; i++) {
    ModelExpectationComputeTask finishedTask;
    try {
      finishedTask = completionService.take().get();
    } catch (InterruptedException e) {
      // Restore the interrupt status so callers further up the stack can
      // still observe the interruption.
      Thread.currentThread().interrupt();
      // TODO: We got interrupted, but that is currently not really supported!
      // For now we fail hard. We hopefully soon
      // handle this case properly!
      throw new IllegalStateException("Interruption is not supported!", e);
    } catch (ExecutionException e) {
      // Only runtime exception can be thrown during training, if one was thrown
      // it should be re-thrown. That could for example be a NullPointerException
      // which is caused through a bug in our implementation.
      throw new RuntimeException("Exception during training: " + e.getMessage(), e);
    }
    // When they are done, retrieve the results ...
    numEvents += finishedTask.getNumEvents();
    numCorrect += finishedTask.getNumCorrect();
    loglikelihood += finishedTask.getLoglikelihood();
  }

  // merge the per-thread expectations into slot 0
  for (int pi = 0; pi < numPreds; pi++) {
    int[] activeOutcomes = params[pi].getOutcomes();
    for (int aoi = 0; aoi < activeOutcomes.length; aoi++) {
      for (int i = 1; i < modelExpects.length; i++) {
        modelExpects[0][pi].updateParameter(aoi, modelExpects[i][pi].getParameters()[aoi]);
      }
    }
  }

  // compute the new parameter values
  for (int pi = 0; pi < numPreds; pi++) {
    double[] observed = observedExpects[pi].getParameters();
    double[] model = modelExpects[0][pi].getParameters();
    int[] activeOutcomes = params[pi].getOutcomes();
    for (int aoi = 0; aoi < activeOutcomes.length; aoi++) {
      if (useGaussianSmoothing) {
        params[pi].updateParameter(aoi, gaussianUpdate(pi, aoi, correctionConstant));
      } else {
        // A zero model expectation makes the log-ratio update below -Infinity;
        // warn so the offending feature/outcome pair can be diagnosed.
        if (model[aoi] == 0) {
          logger.warn("Model expects == 0 for {} {}", predLabels[pi], outcomeLabels[aoi]);
        }
        //params[pi].updateParameter(aoi,(StrictMath.log(observed[aoi]) - StrictMath.log(model[aoi])));
        params[pi].updateParameter(aoi, ((StrictMath.log(observed[aoi]) - StrictMath.log(model[aoi]))
            / correctionConstant));
      }
      // re-initialize every thread's expectation to 0.0 for the next iteration
      for (MutableContext[] modelExpect : modelExpects) {
        modelExpect[pi].setParameter(aoi, 0.0);
      }
    }
  }

  getTrainingProgressMonitor(trainingConfiguration).
      finishedIteration(iteration, numCorrect, numEvents, TrainingMeasure.LOG_LIKELIHOOD, loglikelihood);
  return loglikelihood;
}