in opennlp-tools/src/main/java/opennlp/tools/ml/perceptron/SimplePerceptronSequenceTrainer.java [252:396]
/**
 * Runs one training pass over every sequence of the sequence stream and applies a
 * perceptron update for each sequence that contains at least one mis-predicted event.
 * <p>
 * For such a sequence the per-outcome feature counts of the reference events are added and
 * those of the predicted events are subtracted; the remaining differences are applied to
 * {@code params}. If {@code useAverage} is set, {@code averageParams} is brought up to date
 * lazily through the {@code updates} bookkeeping and is finalized during the last iteration.
 *
 * @param iteration The 1-based number of the iteration to run.
 * @throws IOException Thrown if an IO error occurs while processing the sequence stream.
 */
public void nextIteration(int iteration) throws IOException {
  iteration--; //move to 0-based index
  int numCorrect = 0;
  // Overall number of events compared so far; only reported in trace output.
  int oei = 0;
  // Index of the sequence currently being processed.
  int si = 0;
  List<Map<String, Float>> featureCounts = new ArrayList<>(numOutcomes);
  for (int oi = 0; oi < numOutcomes; oi++) {
    featureCounts.add(new HashMap<>());
  }
  PerceptronModel model = new PerceptronModel(params, predLabels, outcomeLabels);
  sequenceStream.reset();
  Sequence<Event> sequence;
  while ((sequence = sequenceStream.read()) != null) {
    Event[] taggerEvents = sequenceStream.updateContext(sequence, model);
    Event[] events = sequence.getEvents();

    // Compare prediction and reference event by event. The whole sequence is scanned
    // (no early exit) so that numCorrect covers every event.
    boolean update = false;
    for (int ei = 0; ei < events.length; ei++, oei++) {
      if (!taggerEvents[ei].getOutcome().equals(events[ei].getOutcome())) {
        update = true;
      } else {
        numCorrect++;
      }
    }

    if (update) {
      for (int oi = 0; oi < numOutcomes; oi++) {
        featureCounts.get(oi).clear();
      }
      if (logger.isTraceEnabled()) {
        final StringBuilder sb = new StringBuilder();
        for (Event event : events) {
          sb.append(" ").append(event.getOutcome());
        }
        logger.trace("train: {}", sb);
      }

      //training feature count computation
      // Note: oei is deliberately not advanced here, these events were already
      // counted by the comparison loop above.
      for (Event event : events) {
        String[] contextStrings = event.getContext();
        float[] values = event.getValues();
        int oi = omap.get(event.getOutcome());
        Map<String, Float> counts = featureCounts.get(oi);
        for (int ci = 0; ci < contextStrings.length; ci++) {
          // A missing values array means every feature has weight 1.
          float value = (values != null) ? values[ci] : 1;
          counts.merge(contextStrings[ci], value, Float::sum);
        }
      }

      //evaluation feature count computation
      if (logger.isTraceEnabled()) {
        final StringBuilder sb = new StringBuilder();
        for (Event taggerEvent : taggerEvents) {
          sb.append(" ").append(taggerEvent.getOutcome());
        }
        logger.trace("test: {}", sb);
      }
      for (Event taggerEvent : taggerEvents) {
        String[] contextStrings = taggerEvent.getContext();
        float[] values = taggerEvent.getValues();
        int oi = omap.get(taggerEvent.getOutcome());
        Map<String, Float> counts = featureCounts.get(oi);
        for (int ci = 0; ci < contextStrings.length; ci++) {
          float value = (values != null) ? values[ci] : 1;
          Float previous = counts.get(contextStrings[ci]);
          float c = (previous == null) ? -value : previous - value;
          // Features whose reference and predicted counts cancel out need no update.
          if (c == 0f) {
            counts.remove(contextStrings[ci]);
          } else {
            counts.put(contextStrings[ci], c);
          }
        }
      }

      // Apply the remaining count differences to the parameters.
      for (int oi = 0; oi < numOutcomes; oi++) {
        for (Map.Entry<String, Float> entry : featureCounts.get(oi).entrySet()) {
          String feature = entry.getKey();
          int pi = pmap.getOrDefault(feature, -1);
          if (pi == -1) {
            // Feature is not a known predicate; nothing to update.
            continue;
          }
          float delta = entry.getValue();
          if (logger.isTraceEnabled()) {
            logger.trace("{} {} {} {}",
                si, outcomeLabels[oi], feature, delta);
          }
          params[pi].updateParameter(oi, delta);
          if (useAverage) {
            if (updates[pi][oi][VALUE] != 0) {
              // Credit the previously recorded value for every sequence step elapsed since
              // it was recorded. Computed in double so that large corpora cannot overflow
              // the int range before the value is widened.
              double elapsed = (double) numSequences * (iteration - updates[pi][oi][ITER])
                  + (si - updates[pi][oi][EVENT]);
              averageParams[pi].updateParameter(oi, updates[pi][oi][VALUE] * elapsed);
              if (logger.isTraceEnabled()) {
                logger.trace("p avp[{}].{}={}", pi, oi, averageParams[pi].getParameters()[oi]);
              }
            }
            if (logger.isTraceEnabled()) {
              logger.trace("p updates[{}][{}]=({},{},{})({},{},{}) -> {}", pi, oi, updates[pi][oi][ITER],
                  updates[pi][oi][EVENT], updates[pi][oi][VALUE], iteration, oei,
                  params[pi].getParameters()[oi], averageParams[pi].getParameters()[oi]);
            }
            // NOTE(review): the cast drops any fractional part of the parameter; this looks
            // exact only when all feature values are integral - confirm against callers.
            updates[pi][oi][VALUE] = (int) params[pi].getParameters()[oi];
            updates[pi][oi][ITER] = iteration;
            updates[pi][oi][EVENT] = si;
          }
        }
      }
      // Rebuild the model so that the next sequence is tagged with the updated parameters.
      model = new PerceptronModel(params, predLabels, outcomeLabels);
    }
    si++;
  }

  //finish average computation
  if (useAverage && iteration == iterations - 1) {
    double totIterations = (double) iterations * si;
    for (int pi = 0; pi < numPreds; pi++) {
      double[] predParams = averageParams[pi].getParameters();
      for (int oi = 0; oi < numOutcomes; oi++) {
        if (updates[pi][oi][VALUE] != 0) {
          // Credit the last recorded value for the remainder of the training run;
          // computed in double to avoid int overflow.
          double remaining = (double) numSequences * (iterations - updates[pi][oi][ITER])
              - updates[pi][oi][EVENT];
          predParams[oi] += updates[pi][oi][VALUE] * remaining;
        }
        if (predParams[oi] != 0) {
          predParams[oi] /= totIterations;
          averageParams[pi].setParameter(oi, predParams[oi]);
          if (logger.isTraceEnabled()) {
            logger.trace("updates[{}][{}]=({},{},{})({},{},{}) -> {}", pi, oi, updates[pi][oi][ITER],
                updates[pi][oi][EVENT], updates[pi][oi][VALUE], iterations, 0,
                params[pi].getParameters()[oi], averageParams[pi].getParameters()[oi]);
          }
        }
      }
    }
  }
  // The iteration number reported here is the 0-based index.
  logger.info("{}. ({}/{}) {}", iteration, numCorrect,
      numEvents, ((double) numCorrect / numEvents));
}