in src/utils/args.cpp [71:253]
// Parses the command line into this Args object.
//
// argv[1] selects the mode: "train" or "test" ("-h"/"-help" prints help).
// Every following argument must be a "-name value" pair; "--name" is accepted
// as an alias for "-name". "-h"/"-help" may also appear anywhere and takes no
// value. After parsing, required paths (trainFile/testFile and model) and the
// enumerated options (trainMode, loss, similarity, fileFormat, compressFile)
// are validated.
//
// On any usage error a message is written to stderr and the process exits
// with EXIT_FAILURE. Requesting help also exits with EXIT_FAILURE.
//
// Note: the "--" alias is implemented by advancing argv[i] by one character,
// so the caller's argv array is modified.
void Args::parseArgs(int argc, char** argv) {
  if (argc <= 1) {
    cerr << "Usage: need to specify whether it is train or test.\n";
    printHelp();
    exit(EXIT_FAILURE);
  }
  if (strcmp(argv[1], "train") == 0) {
    isTrain = true;
  } else if (strcmp(argv[1], "test") == 0) {
    isTrain = false;
  } else if (strcmp(argv[1], "-h") == 0 || strcmp(argv[1], "-help") == 0) {
    std::cerr << "Here is the help! Usage:" << std::endl;
    printHelp();
    exit(EXIT_FAILURE);
  } else {
    cerr << "Usage: the first argument should be either train or test.\n";
    printHelp();
    exit(EXIT_FAILURE);
  }

  int i = 2;
  // Returns the value token that follows the option at argv[i]. Without this
  // check a trailing option with no value would read argv[argc], which is a
  // null pointer, and crash inside string()/atoi()/atof().
  auto requireValue = [&]() -> const char* {
    if (i + 1 >= argc) {
      cerr << "Missing value for argument: " << argv[i] << endl;
      printHelp();
      exit(EXIT_FAILURE);
    }
    return argv[i + 1];
  };

  while (i < argc) {
    if (argv[i][0] != '-') {
      // Usage error: report on stderr like every other error in this parser.
      cerr << "Provided argument without a dash! Usage:" << endl;
      printHelp();
      exit(EXIT_FAILURE);
    }
    // Accept "--name" as an alias for "-name" by skipping the first dash.
    if (strlen(argv[i]) >= 2 && argv[i][1] == '-') {
      argv[i] = argv[i] + 1;
    }
    if (strcmp(argv[i], "-h") == 0 || strcmp(argv[i], "-help") == 0) {
      std::cerr << "Here is the help! Usage:" << std::endl;
      printHelp();
      exit(EXIT_FAILURE);
    } else if (strcmp(argv[i], "-trainFile") == 0) {
      trainFile = string(requireValue());
    } else if (strcmp(argv[i], "-validationFile") == 0) {
      validationFile = string(requireValue());
    } else if (strcmp(argv[i], "-testFile") == 0) {
      testFile = string(requireValue());
    } else if (strcmp(argv[i], "-predictionFile") == 0) {
      predictionFile = string(requireValue());
    } else if (strcmp(argv[i], "-basedoc") == 0) {
      basedoc = string(requireValue());
    } else if (strcmp(argv[i], "-model") == 0) {
      model = string(requireValue());
    } else if (strcmp(argv[i], "-initModel") == 0) {
      initModel = string(requireValue());
    } else if (strcmp(argv[i], "-fileFormat") == 0) {
      fileFormat = string(requireValue());
    } else if (strcmp(argv[i], "-compressFile") == 0) {
      compressFile = string(requireValue());
    } else if (strcmp(argv[i], "-numGzFile") == 0) {
      numGzFile = atoi(requireValue());
    } else if (strcmp(argv[i], "-label") == 0) {
      label = string(requireValue());
    } else if (strcmp(argv[i], "-weightSep") == 0) {
      // Only the first character of the value is used as the separator.
      weightSep = requireValue()[0];
    } else if (strcmp(argv[i], "-loss") == 0) {
      loss = string(requireValue());
    } else if (strcmp(argv[i], "-similarity") == 0) {
      similarity = string(requireValue());
    } else if (strcmp(argv[i], "-lr") == 0) {
      lr = atof(requireValue());
    } else if (strcmp(argv[i], "-p") == 0) {
      p = atof(requireValue());
    } else if (strcmp(argv[i], "-termLr") == 0) {
      termLr = atof(requireValue());
    } else if (strcmp(argv[i], "-norm") == 0) {
      norm = atof(requireValue());
    } else if (strcmp(argv[i], "-margin") == 0) {
      margin = atof(requireValue());
    } else if (strcmp(argv[i], "-initRandSd") == 0) {
      initRandSd = atof(requireValue());
    } else if (strcmp(argv[i], "-dropoutLHS") == 0) {
      dropoutLHS = atof(requireValue());
    } else if (strcmp(argv[i], "-dropoutRHS") == 0) {
      dropoutRHS = atof(requireValue());
    } else if (strcmp(argv[i], "-wordWeight") == 0) {
      wordWeight = atof(requireValue());
    } else if (strcmp(argv[i], "-dim") == 0) {
      dim = atoi(requireValue());
    } else if (strcmp(argv[i], "-epoch") == 0) {
      epoch = atoi(requireValue());
    } else if (strcmp(argv[i], "-ws") == 0) {
      ws = atoi(requireValue());
    } else if (strcmp(argv[i], "-maxTrainTime") == 0) {
      maxTrainTime = atoi(requireValue());
    } else if (strcmp(argv[i], "-validationPatience") == 0) {
      validationPatience = atoi(requireValue());
    } else if (strcmp(argv[i], "-thread") == 0) {
      thread = atoi(requireValue());
    } else if (strcmp(argv[i], "-maxNegSamples") == 0) {
      maxNegSamples = atoi(requireValue());
    } else if (strcmp(argv[i], "-negSearchLimit") == 0) {
      negSearchLimit = atoi(requireValue());
    } else if (strcmp(argv[i], "-minCount") == 0) {
      minCount = atoi(requireValue());
    } else if (strcmp(argv[i], "-minCountLabel") == 0) {
      minCountLabel = atoi(requireValue());
    } else if (strcmp(argv[i], "-bucket") == 0) {
      bucket = atoi(requireValue());
    } else if (strcmp(argv[i], "-ngrams") == 0) {
      ngrams = atoi(requireValue());
    } else if (strcmp(argv[i], "-K") == 0) {
      K = atoi(requireValue());
    } else if (strcmp(argv[i], "-batchSize") == 0) {
      batchSize = atoi(requireValue());
    } else if (strcmp(argv[i], "-trainMode") == 0) {
      trainMode = atoi(requireValue());
    } else if (strcmp(argv[i], "-verbose") == 0) {
      verbose = isTrue(string(requireValue()));
    } else if (strcmp(argv[i], "-debug") == 0) {
      debug = isTrue(string(requireValue()));
    } else if (strcmp(argv[i], "-adagrad") == 0) {
      adagrad = isTrue(string(requireValue()));
    } else if (strcmp(argv[i], "-shareEmb") == 0) {
      shareEmb = isTrue(string(requireValue()));
    } else if (strcmp(argv[i], "-normalizeText") == 0) {
      normalizeText = isTrue(string(requireValue()));
    } else if (strcmp(argv[i], "-saveEveryEpoch") == 0) {
      saveEveryEpoch = isTrue(string(requireValue()));
    } else if (strcmp(argv[i], "-saveTempModel") == 0) {
      saveTempModel = isTrue(string(requireValue()));
    } else if (strcmp(argv[i], "-useWeight") == 0) {
      useWeight = isTrue(string(requireValue()));
    } else if (strcmp(argv[i], "-trainWord") == 0) {
      trainWord = isTrue(string(requireValue()));
    } else if (strcmp(argv[i], "-excludeLHS") == 0) {
      excludeLHS = isTrue(string(requireValue()));
    } else {
      cerr << "Unknown argument: " << argv[i] << std::endl;
      printHelp();
      exit(EXIT_FAILURE);
    }
    // Skip past the option name and its value.
    i += 2;
  }

  // Both modes need a model path; train also needs training data and test
  // needs test data.
  if (isTrain) {
    if (trainFile.empty() || model.empty()) {
      cerr << "Empty train file or output model path." << endl;
      printHelp();
      exit(EXIT_FAILURE);
    }
  } else {
    if (testFile.empty() || model.empty()) {
      cerr << "Empty test file or model path." << endl;
      printHelp();
      exit(EXIT_FAILURE);
    }
  }
  // check for trainMode
  if ((trainMode < 0) || (trainMode > 5)) {
    cerr << "Unknown trainMode. We currently support the following train modes:\n";
    cerr << "trainMode 0: at training time, one label from RHS is picked as true label; LHS is the same from input.\n";
    cerr << "trainMode 1: at training time, one label from RHS is picked as true label; LHS is the bag of the rest RHS labels.\n";
    cerr << "trainMode 2: at training time, one label from RHS is picked as LHS; the bag of the rest RHS labels becomes the true label.\n";
    cerr << "trainMode 3: at training time, one label from RHS is picked as true label and another label from RHS is picked as LHS.\n";
    cerr << "trainMode 4: at training time, the first label from RHS is picked as LHS and the second one picked as true label.\n";
    cerr << "trainMode 5: continuous bag of words training.\n";
    exit(EXIT_FAILURE);
  }
  // check for loss type
  if (!(loss == "hinge" || loss == "softmax")) {
    cerr << "Unsupported loss type: " << loss << endl;
    exit(EXIT_FAILURE);
  }
  // check for similarity type
  if (!(similarity == "cosine" || similarity == "dot")) {
    cerr << "Unsupported similarity type. Should be either dot or cosine.\n";
    exit(EXIT_FAILURE);
  }
  // check for file format
  if (!(fileFormat == "fastText" || fileFormat == "labelDoc")) {
    cerr << "Unsupported file format type. Should be either fastText or labelDoc.\n";
    exit(EXIT_FAILURE);
  }
  // check for compression type: empty (uncompressed) or gzip only
  if (!(compressFile.empty() || compressFile == "gzip")) {
    cerr << "Currently only support gzip for compressFile.\n";
    exit(EXIT_FAILURE);
  }
}