public void nextIteration()

in opennlp-tools/src/main/java/opennlp/tools/ml/perceptron/SimplePerceptronSequenceTrainer.java [252:396]


  /**
   * Performs one training pass over every sequence provided by {@code sequenceStream}.
   * <p>
   * For each sequence, the reference events are compared with the events returned by
   * {@code sequenceStream.updateContext(sequence, model)} (presumably the sequence as tagged
   * by the current model). If at least one outcome differs, per-outcome feature count
   * differences (reference minus tagged) are computed and added to {@code params}; the model
   * used for subsequent sequences is then rebuilt from the updated parameters.
   * <p>
   * When {@code useAverage} is set, {@code updates} records for every touched parameter its
   * last value and the iteration/sequence at which it was set, so that the contribution of
   * that value can be added to {@code averageParams} lazily. On the last iteration the
   * outstanding contributions are flushed and the sums are divided by
   * {@code iterations * <number of sequences read>}.
   *
   * @param iteration The 1-based number of the iteration to run; it is converted to a
   *                  0-based index internally.
   * @throws IOException Thrown if resetting or reading {@code sequenceStream} fails.
   */
  public void nextIteration(int iteration) throws IOException {
    iteration--; //move to 0-based index
    int numCorrect = 0;
    int oei = 0; // number of events inspected so far in this iteration (trace output only)
    int si = 0;  // index of the sequence currently being processed
    List<Map<String, Float>> featureCounts = new ArrayList<>(numOutcomes);
    for (int oi = 0; oi < numOutcomes; oi++) {
      featureCounts.add(new HashMap<>());
    }
    PerceptronModel model = new PerceptronModel(params, predLabels, outcomeLabels);

    sequenceStream.reset();

    Sequence<Event> sequence;
    while ((sequence = sequenceStream.read()) != null) {
      Event[] taggerEvents = sequenceStream.updateContext(sequence, model);
      Event[] events = sequence.getEvents();
      boolean update = false;
      // No early exit on the first mismatch: every event has to be inspected so that
      // numCorrect covers the whole sequence.
      for (int ei = 0; ei < events.length; ei++, oei++) {
        if (taggerEvents[ei].getOutcome().equals(events[ei].getOutcome())) {
          numCorrect++;
        } else {
          update = true;
        }
      }
      if (update) {
        for (Map<String, Float> counts : featureCounts) {
          counts.clear();
        }

        //training feature count computation
        if (logger.isTraceEnabled()) {
          logger.trace("train: {}", joinOutcomes(events));
        }
        addFeatureCounts(featureCounts, events);

        //evaluation feature count computation
        if (logger.isTraceEnabled()) {
          logger.trace("test: {}", joinOutcomes(taggerEvents));
        }
        subtractFeatureCounts(featureCounts, taggerEvents);

        for (int oi = 0; oi < numOutcomes; oi++) {
          for (Map.Entry<String, Float> entry : featureCounts.get(oi).entrySet()) {
            final String feature = entry.getKey();
            final Float delta = entry.getValue();
            final int pi = pmap.getOrDefault(feature, -1);
            if (pi == -1) {
              continue; // feature is not a known predicate, nothing to update
            }
            if (logger.isTraceEnabled()) {
              logger.trace("{} {} {} {}", si, outcomeLabels[oi], feature, delta);
            }
            params[pi].updateParameter(oi, delta);
            if (useAverage) {
              if (updates[pi][oi][VALUE] != 0) {
                // The previous value was in effect for every sequence processed since it was
                // recorded. Computed in long so that large corpora cannot overflow int.
                final long elapsed = (long) numSequences * (iteration - updates[pi][oi][ITER])
                    + (si - updates[pi][oi][EVENT]);
                averageParams[pi].updateParameter(oi, (long) updates[pi][oi][VALUE] * elapsed);
                if (logger.isTraceEnabled()) {
                  logger.trace("p avp[{}].{}={}", pi, oi, averageParams[pi].getParameters()[oi]);
                }
              }
              if (logger.isTraceEnabled()) {
                logger.trace("p updates[{}][{}]=({},{},{})({},{},{}) -> {}", pi, oi, updates[pi][oi][ITER],
                    updates[pi][oi][EVENT], updates[pi][oi][VALUE], iteration, oei,
                    params[pi].getParameters()[oi], averageParams[pi].getParameters()[oi]);
              }
              // NOTE(review): the cast truncates the parameter to an integer; this loses
              // precision if event values are fractional - confirm this is intended.
              updates[pi][oi][VALUE] = (int) params[pi].getParameters()[oi];
              updates[pi][oi][ITER] = iteration;
              updates[pi][oi][EVENT] = si;
            }
          }
        }
        model = new PerceptronModel(params, predLabels, outcomeLabels);
      }
      si++;
    }
    //finish average computation
    if (useAverage && iteration == iterations - 1) {
      final double totIterations = (double) iterations * si;
      for (int pi = 0; pi < numPreds; pi++) {
        double[] predParams = averageParams[pi].getParameters();
        for (int oi = 0; oi < numOutcomes; oi++) {
          if (updates[pi][oi][VALUE] != 0) {
            // Flush the contribution of the last recorded value up to the end of training;
            // computed in long to avoid int overflow.
            predParams[oi] += (long) updates[pi][oi][VALUE] * ((long) numSequences
                * (iterations - updates[pi][oi][ITER]) - updates[pi][oi][EVENT]);
          }
          if (predParams[oi] != 0) {
            predParams[oi] /= totIterations;
            averageParams[pi].setParameter(oi, predParams[oi]);
            if (logger.isTraceEnabled()) {
              logger.trace("updates[{}][{}]=({},{},{})({},{},{}) -> {}", pi, oi, updates[pi][oi][ITER],
                  updates[pi][oi][EVENT], updates[pi][oi][VALUE], iterations, 0,
                  params[pi].getParameters()[oi], averageParams[pi].getParameters()[oi]);
            }
          }
        }
      }
    }
    logger.info("{}. ({}/{}) {}", iteration, numCorrect,
        numEvents, ((double) numCorrect / numEvents));
  }

  /**
   * Concatenates the outcomes of the given events, each preceded by a single space.
   * Only used to build trace output.
   */
  private static StringBuilder joinOutcomes(Event[] events) {
    final StringBuilder sb = new StringBuilder();
    for (Event event : events) {
      sb.append(" ").append(event.getOutcome());
    }
    return sb;
  }

  /**
   * Adds the context values of every event to the count map of the event's outcome.
   * A context without explicit values counts as {@code 1}.
   */
  private void addFeatureCounts(List<Map<String, Float>> featureCounts, Event[] events) {
    for (Event event : events) {
      final String[] contextStrings = event.getContext();
      final float[] values = event.getValues();
      final int oi = omap.get(event.getOutcome());
      final Map<String, Float> counts = featureCounts.get(oi);
      for (int ci = 0; ci < contextStrings.length; ci++) {
        final float value = (values != null) ? values[ci] : 1;
        counts.merge(contextStrings[ci], value, Float::sum);
      }
    }
  }

  /**
   * Subtracts the context values of every event from the count map of the event's outcome.
   * A context without explicit values counts as {@code 1}. Entries whose count becomes
   * zero are removed so that they do not trigger a parameter update.
   */
  private void subtractFeatureCounts(List<Map<String, Float>> featureCounts, Event[] events) {
    for (Event event : events) {
      final String[] contextStrings = event.getContext();
      final float[] values = event.getValues();
      final int oi = omap.get(event.getOutcome());
      final Map<String, Float> counts = featureCounts.get(oi);
      for (int ci = 0; ci < contextStrings.length; ci++) {
        final float value = (values != null) ? values[ci] : 1;
        final Float previous = counts.get(contextStrings[ci]);
        final float updated = (previous == null) ? -value : previous - value;
        if (updated == 0f) {
          counts.remove(contextStrings[ci]);
        } else {
          counts.put(contextStrings[ci], updated);
        }
      }
    }
  }