void loadConvLM()

in recipes/utilities/convlm_serializer/Utils.cpp [272:314]


/**
 * Builds a ConvLM network (and, optionally, an adaptive-softmax criterion)
 * from an architecture file, then fills their parameters from a binary
 * weight file.
 *
 * @param network          [out] set to the sequential module built from
 *                         archFile via buildSequentialModule, put in eval mode.
 * @param criterion        [out] set to an fl::AdaptiveSoftMaxLoss (eval mode)
 *                         when adaptiveTail is non-empty; otherwise nullptr.
 * @param archFile         path to the architecture description.
 * @param weightFile       path to the binary weight file read by
 *                         loadModelStates.
 * @param outputTokensDim  output dimension forwarded to buildSequentialModule.
 * @param adaptiveTail     forwarded as the tail/cutoff argument of
 *                         fl::AdaptiveSoftMax; an empty vector disables the
 *                         criterion.
 * @param inputSizeAdaptiveSoftmax  input size for fl::AdaptiveSoftMax; only
 *                         used when adaptiveTail is non-empty.
 *
 * Logs at fl::FATAL level if either file is missing, or if the number of
 * states in weightFile differs from the combined parameter count of the
 * network and the (optional) criterion.
 */
void loadConvLM(
    shared_ptr<fl::Module>& network,
    shared_ptr<fl::BinaryModule>& criterion,
    const string& archFile,
    const string& weightFile,
    int outputTokensDim,
    const vector<int>& adaptiveTail /*  = std::vector<int>() */,
    int inputSizeAdaptiveSoftmax /* = 0 */) {
  FL_LOG_IF(fl::FATAL, !fl::lib::fileExists(archFile))
      << "Path to arch file " << archFile << " doesn't exist";
  FL_LOG_IF(fl::FATAL, !fl::lib::fileExists(weightFile))
      << "Path to weight file " << weightFile << " doesn't exist";

  // Create the network from the architecture description.
  // NOTE(review): the literal 1 is presumably the input feature count for
  // buildSequentialModule — confirm against its signature.
  network =
      fl::pkg::runtime::buildSequentialModule(archFile, 1, outputTokensDim);
  network->eval();

  // A criterion is only built when an adaptive-softmax tail is supplied;
  // otherwise callers receive a null criterion.
  if (!adaptiveTail.empty()) {
    auto activation = make_shared<fl::AdaptiveSoftMax>(
        inputSizeAdaptiveSoftmax, adaptiveTail);
    criterion = make_shared<fl::AdaptiveSoftMaxLoss>(activation);
    criterion->eval();
  } else {
    criterion = nullptr;
  }

  // Loading weights from the binary file
  FL_LOG(fl::INFO) << "[LoadConvLM]: Load states";
  auto modelStates = loadModelStates(weightFile);

  // Query each module's parameter list once and reuse the counts for both
  // the consistency check and the diagnostic message.
  const size_t numNetworkParams = network->params().size();
  const size_t numCriterionParams =
      criterion ? criterion->params().size() : 0;
  FL_LOG_IF(
      fl::FATAL,
      modelStates.size() != numNetworkParams + numCriterionParams)
      << "mismatch between the number of parameters in the arch file and the weight file "
      << modelStates.size() << " model states vs " << numNetworkParams
      << " nn params + " << numCriterionParams << " criterion params";

  // Load weight states into network and criterion (criterion may be null).
  FL_LOG(fl::INFO) << "[LoadConvLM]: set params";
  setParams(network, criterion, modelStates);
}