in tools/trainers/retargetTrainer/src/main.cpp [360:517]
int main(int argc, char* argv[])
{
utilities::MillisecondTimer _overallTimer;
utilities::MillisecondTimer _timer;
try
{
_overallTimer.Start();
// create a command line parser
utilities::CommandLineParser commandLineParser(argc, argv);
// add arguments to the command line parser
ParsedRetargetArguments retargetArguments;
commandLineParser.AddOptionSet(retargetArguments);
// parse command line
commandLineParser.Parse();
if (retargetArguments.verbose) std::cout << commandLineParser.GetCurrentValuesString() << std::endl;
_timer.Start();
// load map
if (retargetArguments.verbose) std::cout << "Loading model from " << retargetArguments.inputModelFilename;
auto map = common::LoadMap(retargetArguments.inputModelFilename);
if (retargetArguments.verbose) std::cout << "(" << _timer.Elapsed() << " ms)" << std::endl;
if (retargetArguments.print)
{
PrintModel(map, retargetArguments.refineIterations);
return 0;
}
// Create a map by redirecting a layer or node to be output
bool redirected = false;
if (retargetArguments.removeLastLayers > 0)
{
if (map.GetOutputType() == model::Port::PortType::smallReal)
{
redirected = RedirectNeuralNetworkOutputByLayer<float>(map, retargetArguments.removeLastLayers);
}
else
{
redirected = RedirectNeuralNetworkOutputByLayer<double>(map, retargetArguments.removeLastLayers);
}
std::cout << "Removed last " << retargetArguments.removeLastLayers << " layers from neural network" << std::endl;
}
else if (retargetArguments.targetPortElements.length() > 0)
{
redirected = RedirectModelOutputByPortElements(map, retargetArguments.targetPortElements, retargetArguments.refineIterations);
std::cout << "Redirected output for port elements " << retargetArguments.targetPortElements << " from model" << std::endl;
}
else
{
std::cerr << "Error: Expected valid arguments for either --removeLastLayers or --targetPortElements" << std::endl;
return 1;
}
if (!redirected)
{
std::cerr << "Could not splice model, exiting" << std::endl;
return 1;
}
auto node = map.GetOutput(0).GetNode();
std::cout << "Using output from node of type " << node->GetRuntimeTypeName() << std::endl;
// load dataset and map the output
if (retargetArguments.verbose) std::cout << "Loading data ...";
model::Map result;
if (retargetArguments.multiClass)
{
// This is a multi-class dataset
_timer.Start();
auto stream = utilities::OpenIfstream(retargetArguments.inputDataFilename);
auto multiclassDataset = common::GetMultiClassDataset(stream);
if (retargetArguments.verbose) std::cout << "(" << _timer.Elapsed() << " ms)" << std::endl;
// Obtain a new training dataset for the set of Linear Predictors by running the
// multiclassDataset through the modified model
if (retargetArguments.verbose) std::cout << std::endl
<< "Transforming dataset with compiled model...";
_timer.Start();
auto dataset = common::TransformDatasetWithCompiledMap(multiclassDataset, map, retargetArguments.useBlas);
if (retargetArguments.verbose) std::cout << "(" << _timer.Elapsed() << " ms)" << std::endl;
// Create binary classification datasets for each one versus rest (OVR) case
if (retargetArguments.verbose) std::cout << std::endl
<< "Creating datasets for One vs Rest...";
_timer.Start();
auto datasets = CreateDatasetsForOneVersusRest(dataset);
if (retargetArguments.verbose) std::cout << "(" << _timer.Elapsed() << " ms)" << std::endl;
// Next, train a binary classifier for each case and combine into a
// single model.
_timer.Start();
std::vector<PredictorType> predictors(datasets.size());
for (size_t i = 0; i < datasets.size(); ++i)
{
std::cout << std::endl
<< "=== Training binary classifier for class " << i << " vs Rest ===" << std::endl;
predictors[i] = RetargetModelUsingLinearPredictor(retargetArguments, datasets[i]);
}
if (retargetArguments.verbose) std::cout << "Training completed ...(" << _timer.Elapsed() << " ms)" << std::endl;
// Save the newly spliced model
result = GetRetargetedModel(predictors, map);
}
else
{
// This is a binary classification dataset
_timer.Start();
auto stream = utilities::OpenIfstream(retargetArguments.inputDataFilename);
auto binaryDataset = common::GetDataset(stream);
if (retargetArguments.verbose) std::cout << "Loading dataset took :" << _timer.Elapsed() << " ms" << std::endl;
// Obtain a new training dataset for the Linear Predictor by running the
// binaryDataset through the modified model
if (retargetArguments.verbose) std::cout << std::endl
<< "Transforming dataset with compiled model...";
_timer.Start();
auto dataset = common::TransformDatasetWithCompiledMap(binaryDataset, map);
if (retargetArguments.verbose) std::cout << "(" << _timer.Elapsed() << " ms)" << std::endl;
// Train a linear predictor whose input comes from the previous model
_timer.Start();
auto predictor = RetargetModelUsingLinearPredictor(retargetArguments, dataset);
if (retargetArguments.verbose) std::cout << "Training completed... (" << _timer.Elapsed() << " ms)" << std::endl;
// Save the newly spliced model
result = GetRetargetedModel(predictor, map);
}
common::SaveMap(result, retargetArguments.outputModelFilename);
if (retargetArguments.verbose) std::cout << std::endl
<< "RetargetTrainer completed... (" << _overallTimer.Elapsed() << " ms)" << std::endl;
std::cout << std::endl
<< "New model saved as " << retargetArguments.outputModelFilename << std::endl;
}
catch (const utilities::CommandLineParserPrintHelpException& exception)
{
std::cout << exception.GetHelpText() << std::endl;
return 0;
}
catch (const utilities::CommandLineParserErrorException& exception)
{
std::cerr << "Command line parse error:" << std::endl;
for (const auto& error : exception.GetParseErrors())
{
std::cerr << error.GetMessage() << std::endl;
}
return 1;
}
catch (const utilities::Exception& exception)
{
std::cerr << "exception: " << exception.GetMessage() << std::endl;
return 1;
}
return 0;
}