int main()

in tools/trainers/retargetTrainer/src/main.cpp [360:517]


int main(int argc, char* argv[])
{
    utilities::MillisecondTimer _overallTimer;
    utilities::MillisecondTimer _timer;
    try
    {
        _overallTimer.Start();
        // create a command line parser
        utilities::CommandLineParser commandLineParser(argc, argv);

        // add arguments to the command line parser
        ParsedRetargetArguments retargetArguments;
        commandLineParser.AddOptionSet(retargetArguments);

        // parse command line
        commandLineParser.Parse();
        if (retargetArguments.verbose) std::cout << commandLineParser.GetCurrentValuesString() << std::endl;

        _timer.Start();
        // load map
        if (retargetArguments.verbose) std::cout << "Loading model from " << retargetArguments.inputModelFilename;
        auto map = common::LoadMap(retargetArguments.inputModelFilename);
        if (retargetArguments.verbose) std::cout << "(" << _timer.Elapsed() << " ms)" << std::endl;

        if (retargetArguments.print)
        {
            PrintModel(map, retargetArguments.refineIterations);
            return 0;
        }

        // Create a map by redirecting a layer or node to be output
        bool redirected = false;
        if (retargetArguments.removeLastLayers > 0)
        {
            if (map.GetOutputType() == model::Port::PortType::smallReal)
            {
                redirected = RedirectNeuralNetworkOutputByLayer<float>(map, retargetArguments.removeLastLayers);
            }
            else
            {
                redirected = RedirectNeuralNetworkOutputByLayer<double>(map, retargetArguments.removeLastLayers);
            }
            std::cout << "Removed last " << retargetArguments.removeLastLayers << " layers from neural network" << std::endl;
        }
        else if (retargetArguments.targetPortElements.length() > 0)
        {
            redirected = RedirectModelOutputByPortElements(map, retargetArguments.targetPortElements, retargetArguments.refineIterations);
            std::cout << "Redirected output for port elements " << retargetArguments.targetPortElements << " from model" << std::endl;
        }
        else
        {
            std::cerr << "Error: Expected valid arguments for either --removeLastLayers or --targetPortElements" << std::endl;
            return 1;
        }

        if (!redirected)
        {
            std::cerr << "Could not splice model, exiting" << std::endl;
            return 1;
        }

        auto node = map.GetOutput(0).GetNode();
        std::cout << "Using output from node of type " << node->GetRuntimeTypeName() << std::endl;

        // load dataset and map the output
        if (retargetArguments.verbose) std::cout << "Loading data ...";
        model::Map result;
        if (retargetArguments.multiClass)
        {
            // This is a multi-class dataset
            _timer.Start();
            auto stream = utilities::OpenIfstream(retargetArguments.inputDataFilename);
            auto multiclassDataset = common::GetMultiClassDataset(stream);
            if (retargetArguments.verbose) std::cout << "(" << _timer.Elapsed() << " ms)" << std::endl;

            // Obtain a new training dataset for the set of Linear Predictors by running the
            // multiclassDataset through the modified model
            if (retargetArguments.verbose) std::cout << std::endl
                                                     << "Transforming dataset with compiled model...";
            _timer.Start();

            auto dataset = common::TransformDatasetWithCompiledMap(multiclassDataset, map, retargetArguments.useBlas);
            if (retargetArguments.verbose) std::cout << "(" << _timer.Elapsed() << " ms)" << std::endl;

            // Create binary classification datasets for each one versus rest (OVR) case
            if (retargetArguments.verbose) std::cout << std::endl
                                                     << "Creating datasets for One vs Rest...";
            _timer.Start();
            auto datasets = CreateDatasetsForOneVersusRest(dataset);
            if (retargetArguments.verbose) std::cout << "(" << _timer.Elapsed() << " ms)" << std::endl;

            // Next, train a binary classifier for each case and combine into a
            // single model.
            _timer.Start();
            std::vector<PredictorType> predictors(datasets.size());
            for (size_t i = 0; i < datasets.size(); ++i)
            {
                std::cout << std::endl
                          << "=== Training binary classifier for class " << i << " vs Rest ===" << std::endl;

                predictors[i] = RetargetModelUsingLinearPredictor(retargetArguments, datasets[i]);
            }
            if (retargetArguments.verbose) std::cout << "Training completed ...(" << _timer.Elapsed() << " ms)" << std::endl;

            // Save the newly spliced model
            result = GetRetargetedModel(predictors, map);
        }
        else
        {
            // This is a binary classification dataset
            _timer.Start();
            auto stream = utilities::OpenIfstream(retargetArguments.inputDataFilename);
            auto binaryDataset = common::GetDataset(stream);
            if (retargetArguments.verbose) std::cout << "Loading dataset took :" << _timer.Elapsed() << " ms" << std::endl;
            // Obtain a new training dataset for the Linear Predictor by running the
            // binaryDataset through the modified model
            if (retargetArguments.verbose) std::cout << std::endl
                                                     << "Transforming dataset with compiled model...";
            _timer.Start();

            auto dataset = common::TransformDatasetWithCompiledMap(binaryDataset, map);
            if (retargetArguments.verbose) std::cout << "(" << _timer.Elapsed() << " ms)" << std::endl;

            // Train a linear predictor whose input comes from the previous model
            _timer.Start();
            auto predictor = RetargetModelUsingLinearPredictor(retargetArguments, dataset);
            if (retargetArguments.verbose) std::cout << "Training completed... (" << _timer.Elapsed() << " ms)" << std::endl;

            // Save the newly spliced model
            result = GetRetargetedModel(predictor, map);
        }
        common::SaveMap(result, retargetArguments.outputModelFilename);
        if (retargetArguments.verbose) std::cout << std::endl
                                                 << "RetargetTrainer completed... (" << _overallTimer.Elapsed() << " ms)" << std::endl;
        std::cout << std::endl
                  << "New model saved as " << retargetArguments.outputModelFilename << std::endl;
    }
    catch (const utilities::CommandLineParserPrintHelpException& exception)
    {
        std::cout << exception.GetHelpText() << std::endl;
        return 0;
    }
    catch (const utilities::CommandLineParserErrorException& exception)
    {
        std::cerr << "Command line parse error:" << std::endl;
        for (const auto& error : exception.GetParseErrors())
        {
            std::cerr << error.GetMessage() << std::endl;
        }
        return 1;
    }
    catch (const utilities::Exception& exception)
    {
        std::cerr << "exception: " << exception.GetMessage() << std::endl;
        return 1;
    }
    return 0;
}