void loadModule()

in recipes/utilities/convlm_serializer/Utils.cpp [160:249]


/**
 * Recursively loads serialized ConvLM parameter states into `subModule`.
 *
 * Modules without parameters are skipped. Containers are traversed rather
 * than loaded directly:
 *  - fl::Sequential: children are visited in order; each child receives the
 *    parameter offset of its first parameter.
 *  - fl::Residual: non-projection children are loaded first, then the
 *    projection children, so states are consumed in that order.
 *  - fl::AdaptiveSoftMaxLoss: consumes exactly as many consecutive states as
 *    the module has parameters.
 * Any other module consumes the run of consecutive states that share the
 * layer name of states[loadIdx].
 *
 * Logs FATAL if the required states are not available.
 *
 * @param states     serialized parameter states, consumed in order
 * @param mainModule top-level module, forwarded to loadLayer
 * @param subModule  module currently being loaded
 * @param loadIdx    in/out: index of the next unconsumed entry in `states`
 * @param paramIdx   offset of subModule's first parameter (presumably within
 *                   mainModule->params(); confirm against loadLayer)
 */
void loadModule(
    vector<ConvLMParamState>& states,
    shared_ptr<fl::Module> mainModule,
    shared_ptr<fl::Module> subModule,
    int& loadIdx,
    int paramIdx) {
  const int nParams = static_cast<int>(subModule->params().size());
  const string moduleName = subModule->prettyString();
  // if no parameters for layer then skip loading weights for it
  if (nParams == 0) {
    FL_LOG(fl::INFO) << "[LoadModule]: Skip loading params for " << moduleName;
    return;
  }

  if (auto sequential = dynamic_pointer_cast<fl::Sequential>(subModule)) {
    // in the sequential block: children own consecutive parameter ranges
    FL_LOG(fl::INFO) << "[LoadModule]: Load sequential block " << moduleName;
    const auto& submodules = sequential->modules();
    for (const auto& smd : submodules) {
      loadModule(states, mainModule, smd, loadIdx, paramIdx);
      paramIdx += static_cast<int>(smd->params().size());
    }
  } else if (auto residual = dynamic_pointer_cast<fl::Residual>(subModule)) {
    // in the res block
    FL_LOG(fl::INFO) << "[LoadModule]: Load residual block " << moduleName;
    const auto& submodules = residual->modules();
    const auto projectionIndices = residual->getProjectionsIndices();
    const int nSub = static_cast<int>(submodules.size());

    // offsets[i] = number of parameters held by children 0..i-1, i.e. the
    // offset of child i's first parameter relative to paramIdx. Computed up
    // front so both passes below use the same offsets.
    std::vector<int64_t> offsets(nSub, 0);
    for (int ind = 1; ind < nSub; ++ind) {
      offsets[ind] = offsets[ind - 1] + submodules[ind - 1]->params().size();
    }

    // Loads every child whose projection status matches `projections`.
    auto loadChildren = [&](bool projections) {
      for (int ind = 0; ind < nSub; ++ind) {
        const bool isProjection =
            projectionIndices.find(ind) != projectionIndices.end();
        if (isProjection == projections) {
          loadModule(
              states,
              mainModule,
              submodules[ind],
              loadIdx,
              paramIdx + offsets[ind]);
        }
      }
    };
    // load modules before loading projection matrices
    loadChildren(false);
    loadChildren(true);
  } else if (dynamic_pointer_cast<fl::AdaptiveSoftMaxLoss>(subModule)) {
    FL_LOG(fl::INFO) << "[LoadModule]: Load adaptive softmax loss "
                     << moduleName;
    // The loss consumes nParams consecutive states; make sure they exist
    // before handing their indices to loadLayer.
    FL_LOG_IF(fl::FATAL, loadIdx + nParams > static_cast<int>(states.size()))
        << "[LoadModule]: states index is out of range for " << moduleName;
    vector<int> moduleStateIndices(nParams);
    std::iota(moduleStateIndices.begin(), moduleStateIndices.end(), loadIdx);
    loadIdx += nParams;
    loadLayer(
        states,
        moduleStateIndices,
        mainModule,
        subModule,
        moduleName,
        paramIdx);
  } else {
    // collect indices for all weights corresponding to the same layer name
    const int nStates = static_cast<int>(states.size());
    FL_LOG_IF(fl::FATAL, loadIdx >= nStates)
        << "[LoadModule]: states index is out of range";
    const auto& loadModuleName = states[loadIdx].layerName;
    vector<int> moduleStateIndices({loadIdx++});
    while (loadIdx < nStates && states[loadIdx].layerName == loadModuleName) {
      moduleStateIndices.push_back(loadIdx++);
    }
    FL_LOG(fl::INFO) << "[LoadModule]: Load module " << loadModuleName
                     << " into " << moduleName;
    loadLayer(
        states,
        moduleStateIndices,
        mainModule,
        subModule,
        moduleName,
        paramIdx);
  }
}