in src/starspace.cpp [168:213]
/**
 * Runs the main training loop over trainData_ for up to args_->epoch epochs.
 *
 * Each epoch:
 *  - if args_->saveEveryEpoch is set, checkpoints the model via saveModel()
 *    and saveModelTsv() before every epoch except the first. With
 *    args_->saveTempModel the file name gets an "_epoch<i>" suffix; otherwise
 *    args_->model is overwritten each time.
 *  - trains one epoch with a learning rate that decays linearly from
 *    args_->lr towards ~1e-9 across all epochs.
 *  - if validData_ is set, evaluates the validation error. Training stops early
 *    once more than args_->validationPatience *consecutive* evaluations fail
 *    to improve on the best error seen so far.
 *  - stops once the wall-clock time since training began exceeds
 *    args_->maxTrainTime seconds. This is checked only after each epoch.
 */
void StarSpace::train() {
  float rate = args_->lr;
  // Per-epoch linear decay step. Guard the division so a non-positive epoch
  // count cannot divide by zero (the loop below does not run in that case).
  const float decrPerEpoch =
      args_->epoch > 0 ? (rate - 1e-9) / args_->epoch : 0.0f;

  int impatience = 0;          // consecutive validations without improvement
  float best_valid_err = 1e9;  // lowest validation error observed so far
  // Start of training. Also handed to model_->train, presumably for
  // progress/ETA reporting (confirm in the model implementation).
  auto t_start = std::chrono::high_resolution_clock::now();
  for (int i = 0; i < args_->epoch; i++) {
    if (args_->saveEveryEpoch && i > 0) {
      // Checkpoint the state reached after the previous i epochs.
      auto filename = args_->model;
      if (args_->saveTempModel) {
        filename = filename + "_epoch" + std::to_string(i);
      }
      saveModel(filename);
      saveModelTsv(filename + ".tsv");
    }
    cout << "Training epoch " << i << ": " << rate << ' ' << decrPerEpoch << endl;
    // The last two arguments are presumably the learning rate at the start and
    // at the end of this epoch (verify against the model's train signature).
    auto err = model_->train(trainData_, args_->thread,
                             t_start, i,
                             rate, rate - decrPerEpoch);
    // %c%c%c prints the UTF-8 byte sequence E2 98 83, i.e. U+2603 (snowman).
    printf("\n ---+++ %20s %4d Train error : %3.8f +++--- %c%c%c\n",
           "Epoch", i, err,
           0xe2, 0x98, 0x83);
    if (validData_ != nullptr) {
      auto valid_err = model_->test(validData_, args_->thread);
      cout << "\nValidation error: " << valid_err << endl;
      if (valid_err > best_valid_err) {
        impatience += 1;
        if (impatience > args_->validationPatience) {
          cout << "Ran out of Patience! Early stopping based on validation set." << endl;
          break;
        }
      } else {
        best_valid_err = valid_err;
        // Patience counts *consecutive* non-improving epochs. An improvement
        // (or tie) resets the counter so isolated bad epochs do not
        // accumulate into a premature stop.
        impatience = 0;
      }
    }
    rate -= decrPerEpoch;
    // Elapsed wall-clock time in seconds (duration<double> defaults to seconds).
    auto t_end = std::chrono::high_resolution_clock::now();
    auto tot_spent = std::chrono::duration<double>(t_end - t_start).count();
    if (tot_spent > args_->maxTrainTime) {
      cout << "MaxTrainTime exceeded." << endl;
      break;
    }
  }
}