void StarSpace::train()

in src/starspace.cpp [168:213]


/**
 * Runs the training loop for up to args_->epoch epochs.
 *
 * Learning rate: starts at args_->lr and is decreased linearly by
 * decrPerEpoch after every epoch, so that after args_->epoch decrements it
 * would be ~1e-9. Each model_->train call receives the rate at the start of
 * the epoch and the rate it should end the epoch at.
 *
 * Checkpointing: when args_->saveEveryEpoch is set, the model trained so far
 * is written (saveModel plus saveModelTsv with a ".tsv" suffix) at the start
 * of every epoch except the first. With args_->saveTempModel the file name
 * gets an "_epoch<i>" suffix; otherwise args_->model is overwritten each
 * time. The model after the final epoch is NOT saved here (presumably the
 * caller does that — verify).
 *
 * Stopping: the loop ends early when
 *   - validData_ is non-null and the validation error has failed to improve
 *     on the best seen so far for more than args_->validationPatience
 *     consecutive epochs, or
 *   - wall-clock time since the start of training exceeds
 *     args_->maxTrainTime seconds.
 */
void StarSpace::train() {
  float rate = args_->lr;
  // Linear decay ending at ~1e-9 rather than exactly 0.
  float decrPerEpoch = (rate - 1e-9) / args_->epoch;

  int impatience = 0;          // consecutive epochs without validation improvement
  float best_valid_err = 1e9;  // sentinel; presumably above any real validation error
  auto t_start = std::chrono::high_resolution_clock::now();
  for (int i = 0; i < args_->epoch; i++) {
    if (args_->saveEveryEpoch && i > 0) {
      // Checkpoint the result of the i epochs completed so far.
      auto filename = args_->model;
      if (args_->saveTempModel) {
        filename = filename + "_epoch" + std::to_string(i);
      }
      saveModel(filename);
      saveModelTsv(filename + ".tsv");
    }
    cout << "Training epoch " << i << ": " << rate << ' ' << decrPerEpoch << endl;
    // t_start is the start of the whole run (set once above), not of this epoch.
    auto err = model_->train(trainData_, args_->thread,
                             t_start, i,
                             rate, rate - decrPerEpoch);
    // The three trailing bytes are the UTF-8 encoding of U+2603 (snowman).
    printf("\n ---+++ %20s %4d Train error : %3.8f +++--- %c%c%c\n",
           "Epoch", i, err,
           0xe2, 0x98, 0x83);
    if (validData_ != nullptr) {
      auto valid_err = model_->test(validData_, args_->thread);
      cout << "\nValidation error: " << valid_err << endl;
      if (valid_err > best_valid_err) {
        impatience += 1;
        if (impatience > args_->validationPatience) {
          cout << "Ran out of Patience! Early stopping based on validation set." << endl;
          break;
        }
      } else {
        best_valid_err = valid_err;
        // Patience counts *consecutive* non-improving epochs, so a new best
        // restarts the count. Without this reset, isolated bad epochs spread
        // across a long run would accumulate and could stop training while
        // the validation error is still improving.
        impatience = 0;
      }
    }
    rate -= decrPerEpoch;

    // duration<double> with the default period measures seconds.
    auto t_end = std::chrono::high_resolution_clock::now();
    auto tot_spent = std::chrono::duration<double>(t_end - t_start).count();
    if (tot_spent > args_->maxTrainTime) {
      cout << "MaxTrainTime exceeded." << endl;
      break;
    }
  }
}