inline void LoadCheckpoint()

in awstreamer/gst_plugins/mxnet/src/common.hpp [88:116]


inline void LoadCheckpoint(const std::string prefix, const unsigned int epoch,
                           Symbol* symbol, std::map<std::string, NDArray>* arg_params,
                           std::map<std::string, NDArray>* aux_params,
                           Context ctx = Context::cpu()) {
    // load symbol from JSON
    Symbol new_symbol = Symbol::Load(prefix + "-symbol.json");
    // load parameters
    std::stringstream ss;
    ss << std::setw(4) << std::setfill('0') << epoch;
    std::string filepath = prefix + "-" + ss.str() + ".params";
    std::map<std::string, NDArray> params = NDArray::LoadToMap(filepath);
    std::map<std::string, NDArray> args;
    std::map<std::string, NDArray> auxs;
    for (auto iter : params) {
        std::string type = iter.first.substr(0, 4);
        std::string name = iter.first.substr(4);
        if (type == "arg:")
            args[name] = iter.second.Copy(ctx);
        else if (type == "aux:")
            auxs[name] = iter.second.Copy(ctx);
        else
            continue;
    }
    NDArray::WaitAll();

    *symbol = new_symbol;
    *arg_params = args;
    *aux_params = auxs;
}