public GISModel trainModel()

in opennlp-tools/src/main/java/opennlp/tools/ml/maxent/GISTrainer.java [345:490]


  /**
   * Trains a GIS model from the events indexed by {@code di}.
   *
   * <p>The method copies the indexed training data into this trainer's fields, derives the
   * correction constant, allocates the parameter and expectation contexts for every predicate
   * and then delegates the parameter estimation itself to {@code findParameters}.
   *
   * @param iterations Passed unchanged to {@code findParameters}; presumably the number of
   *                   training iterations to perform.
   * @param di         The indexer that supplies the contexts, values, event counts, outcome list,
   *                   outcome labels and predicate labels used for training.
   * @param modelPrior The prior to train with. It is stored in this trainer and receives the
   *                   outcome and predicate labels via {@code setLabels}.
   * @param threads    Determines how many model expectation arrays are allocated
   *                   (one per thread). Must be at least {@code 1}.
   * @return A {@link GISModel} built from the trained parameters, the predicate labels and
   *         the outcome labels.
   * @throws IllegalArgumentException Thrown if {@code threads} is less than one.
   */
  public GISModel trainModel(int iterations, DataIndexer di, Prior modelPrior, int threads) {

    if (threads <= 0) {
      throw new IllegalArgumentException("threads must be at least one or greater but is " + threads + "!");
    }

    // One array of model expectations per thread; the per-predicate entries are created below.
    modelExpects = new MutableContext[threads][];

    /* Incorporate all of the needed info *****/
    logger.info("Incorporating indexed data for training...");
    contexts = di.getContexts();
    values = di.getValues();
    numTimesEventsSeen = di.getNumTimesEventsSeen();
    numUniqueEvents = contexts.length;
    this.prior = modelPrior;

    // Determine the correction constant: the largest per-event total, where the total is the
    // number of contexts for events without values and the sum of the values otherwise.
    double correctionConstant = 0;
    for (int ci = 0; ci < contexts.length; ci++) {
      if (values == null || values[ci] == null) {
        if (contexts[ci].length > correctionConstant) {
          correctionConstant = contexts[ci].length;
        }
      } else {
        // Sum from a zero seed so that an empty value row contributes 0 instead of
        // failing on values[ci][0].
        float cl = 0;
        for (int vi = 0; vi < values[ci].length; vi++) {
          cl += values[ci][vi];
        }

        if (cl > correctionConstant) {
          correctionConstant = cl;
        }
      }
    }
    logger.info("done.");

    outcomeLabels = di.getOutcomeLabels();
    outcomeList = di.getOutcomeList();
    numOutcomes = outcomeLabels.length;

    predLabels = di.getPredLabels();
    prior.setLabels(outcomeLabels, predLabels);
    numPreds = predLabels.length;

    logger.info("\tNumber of Event Tokens: {} " +
        "\n\t Number of Outcomes: {} " +
        "\n\t Number of Predicates: {}", numUniqueEvents, numOutcomes, numPreds);

    // set up feature arrays: predCount[predicate][outcome] accumulates how often (weighted by
    // the value, if any) a predicate was seen together with an outcome.
    float[][] predCount = new float[numPreds][numOutcomes];
    for (int ti = 0; ti < numUniqueEvents; ti++) {
      // Whether this event carries explicit values does not change per context.
      final boolean hasValues = values != null && values[ti] != null;
      for (int j = 0; j < contexts[ti].length; j++) {
        if (hasValues) {
          predCount[contexts[ti][j]][outcomeList[ti]] += numTimesEventsSeen[ti] * values[ti][j];
        } else {
          predCount[contexts[ti][j]][outcomeList[ti]] += numTimesEventsSeen[ti];
        }
      }
    }

    // A fake "observation" to cover features which are not detected in
    // the data.  The default is to assume that we observed "1/10th" of a
    // feature during training.
    final double smoothingObservation = _smoothingObservation;

    // Get the observed expectations of the features. Strictly speaking,
    // we should divide the counts by the number of Tokens, but because of
    // the way the model's expectations are approximated in the
    // implementation, this is cancelled out when we compute the next
    // iteration of a parameter, making the extra divisions wasteful.
    params = new MutableContext[numPreds];
    for (int i = 0; i < modelExpects.length; i++) {
      modelExpects[i] = new MutableContext[numPreds];
    }
    observedExpects = new MutableContext[numPreds];

    // The model does not need the correction constant or the correction feature. The correction
    // constant is only needed during training (it is handed to findParameters below), and the
    // correction feature is not necessary.
    // For compatibility reasons the model contains from now on a correction constant of 1,
    // and a correction param 0.
    evalParams = new EvalParameters(params, numOutcomes);

    // Scratch buffer for the outcomes seen with the current predicate, and the shared pattern
    // used whenever a predicate is active for every outcome.
    int[] activeOutcomes = new int[numOutcomes];
    int[] allOutcomesPattern = new int[numOutcomes];
    for (int oi = 0; oi < numOutcomes; oi++) {
      allOutcomesPattern[oi] = oi;
    }

    for (int pi = 0; pi < numPreds; pi++) {
      int numActiveOutcomes = 0;
      int[] outcomePattern;
      if (useSimpleSmoothing) {
        // With simple smoothing every outcome is treated as active for every predicate.
        numActiveOutcomes = numOutcomes;
        outcomePattern = allOutcomesPattern;
      } else { //determine active outcomes
        for (int oi = 0; oi < numOutcomes; oi++) {
          if (predCount[pi][oi] > 0) {
            activeOutcomes[numActiveOutcomes] = oi;
            numActiveOutcomes++;
          }
        }
        if (numActiveOutcomes == numOutcomes) {
          outcomePattern = allOutcomesPattern;
        } else {
          outcomePattern = new int[numActiveOutcomes];
          System.arraycopy(activeOutcomes, 0, outcomePattern, 0, numActiveOutcomes);
        }
      }

      params[pi] = new MutableContext(outcomePattern, new double[numActiveOutcomes]);
      for (int i = 0; i < modelExpects.length; i++) {
        modelExpects[i][pi] = new MutableContext(outcomePattern, new double[numActiveOutcomes]);
      }
      observedExpects[pi] = new MutableContext(outcomePattern, new double[numActiveOutcomes]);

      for (int aoi = 0; aoi < numActiveOutcomes; aoi++) {
        int oi = outcomePattern[aoi];
        params[pi].setParameter(aoi, 0.0);
        for (MutableContext[] modelExpect : modelExpects) {
          modelExpect[pi].setParameter(aoi, 0.0);
        }
        if (predCount[pi][oi] > 0) {
          observedExpects[pi].setParameter(aoi, predCount[pi][oi]);
        } else if (useSimpleSmoothing) {
          // Unseen predicate/outcome pair: fall back to the smoothing observation.
          observedExpects[pi].setParameter(aoi, smoothingObservation);
        }
      }
    }

    logger.info("...done.");

    /* Find the parameters *****/
    if (threads == 1) {
      logger.info("Computing model parameters ...");
    } else {
      logger.info("Computing model parameters in {} threads...", threads);
    }

    findParameters(iterations, correctionConstant);

    // Create and return the model
    return new GISModel(params, predLabels, outcomeLabels);

  }