void StarSpace::evaluate()

in src/starspace.cpp [417:490]


/**
 * Evaluates the model on the test data and reports ranking metrics.
 *
 * Side effects:
 *  - Sets args_->dropoutLHS and args_->dropoutRHS to 0. This change is
 *    persistent: it is not restored when the function returns.
 *  - Loads the base documents via loadBaseDocs().
 *  - Prints the aggregated metrics. When args_->debug is set, it also prints
 *    the metrics of each worker thread.
 *  - If args_->predictionFile is non-empty, writes every test example and its
 *    predictions to that file. If the file cannot be opened, an error is
 *    printed to stderr and nothing is written.
 *
 * Terminates the process with EXIT_FAILURE when args_->trainMode == 5, for
 * which testing is undefined.
 */
void StarSpace::evaluate() {
  // Testing is not defined for trainMode 5.
  if (args_->trainMode == 5) {
    std::cerr << "Test is undefined in trainMode 5. Please use other trainMode for testing.\n";
    exit(EXIT_FAILURE);
  }

  // No dropout while scoring test examples (note: modifies the shared args).
  args_->dropoutLHS = 0.0;
  args_->dropoutRHS = 0.0;

  loadBaseDocs();
  const int N = testData_->getSize();

  // A non-positive thread count would divide by zero below; run at least one worker.
  const int numThreads = std::max(1, static_cast<int>(args_->thread));
  // Integer ceil(N / numThreads). The float-based version lost precision for
  // large N and produced 0 (tripping an assert) for an empty test set.
  // Clamping to 1 keeps every chunk range well-formed when N == 0.
  const int numPerThread = std::max(1, (N + numThreads - 1) / numThreads);

  vector<Metrics> metrics(numThreads);
  // predictions[i] is filled by evaluateOne() for test example i.
  vector<vector<Predictions>> predictions(N);

  vector<ParseResults> examples;
  testData_->getNextKExamples(N, examples);

  // Scores examples [start, end) and accumulates their metrics into slot idx.
  // Each worker writes only its own metrics slot and a disjoint range of
  // `predictions`.
  auto evalThread = [&] (int idx, int start, int end) {
    metrics[idx].clear();
    for (int i = start; i < end; i++) {
      auto s = evaluateOne(examples[i].LHSTokens, examples[i].RHSTokens, predictions[i], args_->excludeLHS);
      metrics[idx].add(s);
    }
  };

  vector<thread> threads;
  threads.reserve(numThreads);
  for (int i = 0; i < numThreads; i++) {
    const int start = std::min(i * numPerThread, N);
    const int end = std::min(start + numPerThread, N);
    assert(end >= start);
    // Every slot gets a worker, even one with an empty range, so that each
    // metrics[i] is cleared before it is aggregated below.
    threads.emplace_back(evalThread, i, start, end);
  }
  for (auto& t : threads) t.join();

  Metrics result;
  result.clear();
  for (int i = 0; i < numThreads; i++) {
    if (args_->debug) { metrics[i].print(); }
    result.add(metrics[i]);
  }
  result.average();
  result.print();

  if (!args_->predictionFile.empty()) {
    // Print out prediction results to file; the stream is closed on scope exit.
    ofstream ofs(args_->predictionFile);
    if (!ofs) {
      std::cerr << "Failed to open prediction file: " << args_->predictionFile << "\n";
      return;
    }
    for (int i = 0; i < N; i++) {
      ofs << "Example " << i << ":\nLHS:\n";
      printDoc(ofs, examples[i].LHSTokens);
      ofs << "RHS: \n";
      printDoc(ofs, examples[i].RHSTokens);
      ofs << "Predictions: \n";
      for (const auto& pred : predictions[i]) {
        // pred.second == 0 refers to this example's own RHS and is marked "(++)".
        // Any other value k refers to baseDocs_[k - 1] and is marked "(--)".
        // pred.first is written verbatim; presumably it is the candidate's score.
        if (pred.second == 0) {
          ofs << "(++) [" << pred.first << "]\t";
          printDoc(ofs, examples[i].RHSTokens);
        } else {
          ofs << "(--) [" << pred.first << "]\t";
          printDoc(ofs, baseDocs_[pred.second - 1]);
        }
      }
      ofs << "\n";
    }
  }
}