in recipes/utilities/convlm_serializer/Utils.cpp [160:249]
/**
 * Recursively loads serialized ConvLM parameter states into `subModule`.
 *
 * Dispatch:
 *  - modules without parameters are skipped;
 *  - fl::Sequential: children are loaded in order;
 *  - fl::Residual: all non-projection children are loaded first, then the
 *    projection children, each pass in index order;
 *  - fl::AdaptiveSoftMaxLoss: consumes exactly as many consecutive entries of
 *    `states` as the module has parameters;
 *  - any other module: consumes the run of consecutive entries of `states`
 *    whose layerName equals states[loadIdx].layerName.
 * Leaf modules are handed to loadLayer together with the consumed indices.
 *
 * @param states     serialized parameter states, consumed sequentially.
 * @param mainModule top-level module; forwarded unchanged to loadLayer.
 * @param subModule  module (or container) currently being loaded.
 * @param loadIdx    in/out cursor into `states`; advanced past every entry
 *                   consumed by this call (including nested calls).
 * @param paramIdx   parameter offset of `subModule`, forwarded to loadLayer;
 *                   each child receives paramIdx plus the parameter counts of
 *                   the siblings that precede it.
 * Out-of-range accesses to `states` are reported via FL_LOG_IF(fl::FATAL, ..).
 */
void loadModule(
    vector<ConvLMParamState>& states,
    shared_ptr<fl::Module> mainModule,
    shared_ptr<fl::Module> subModule,
    int& loadIdx,
    int paramIdx) {
  const int nParams = subModule->params().size();
  string moduleName = subModule->prettyString();
  // if no parameters for layer then skip loading weights for it
  if (nParams == 0) {
    FL_LOG(fl::INFO) << "[LoadModule]: Skip loading params for " << moduleName;
    return;
  }

  if (auto sequential = dynamic_pointer_cast<fl::Sequential>(subModule)) {
    // in the sequential block
    FL_LOG(fl::INFO) << "[LoadModule]: Load sequential block " << moduleName;
    for (const auto& child : sequential->modules()) {
      loadModule(states, mainModule, child, loadIdx, paramIdx);
      paramIdx += static_cast<int>(child->params().size());
    }
  } else if (auto residual = dynamic_pointer_cast<fl::Residual>(subModule)) {
    // in the res block
    FL_LOG(fl::INFO) << "[LoadModule]: Load residual block " << moduleName;
    auto submodules = residual->modules();
    auto projectionIndices = residual->getProjectionsIndices();
    const int nSubmodules = static_cast<int>(submodules.size());

    // cumParamSize[i] = number of params in submodules[0 .. i-1], i.e. the
    // offset of submodules[i]'s params relative to paramIdx.
    std::vector<int64_t> cumParamSize(submodules.size());
    for (int ind = 1; ind < nSubmodules; ind++) {
      cumParamSize[ind] =
          cumParamSize[ind - 1] + submodules[ind - 1]->params().size();
    }

    auto isProjection = [&projectionIndices](int ind) {
      return projectionIndices.find(ind) != projectionIndices.end();
    };
    auto loadChild = [&](int ind) {
      loadModule(
          states,
          mainModule,
          submodules[ind],
          loadIdx,
          paramIdx + cumParamSize[ind]);
    };

    // load modules before loading projection matrices; the order matters
    // because every call advances the shared loadIdx cursor.
    for (int ind = 0; ind < nSubmodules; ind++) {
      if (!isProjection(ind)) {
        loadChild(ind);
      }
    }
    for (int ind = 0; ind < nSubmodules; ind++) {
      if (isProjection(ind)) {
        loadChild(ind);
      }
    }
  } else if (dynamic_pointer_cast<fl::AdaptiveSoftMaxLoss>(subModule)) {
    FL_LOG(fl::INFO) << "[LoadModule]: Load adaptive softmax loss "
                     << moduleName;
    // This branch consumes nParams entries starting at loadIdx; make sure all
    // of them exist before handing the indices to loadLayer (same guard as
    // the generic branch below).
    FL_LOG_IF(
        fl::FATAL,
        loadIdx < 0 ||
            static_cast<size_t>(loadIdx) + static_cast<size_t>(nParams) >
                states.size())
        << "[LoadModule]: states index is out of range";
    vector<int> moduleStateIndices(nParams);
    std::iota(moduleStateIndices.begin(), moduleStateIndices.end(), loadIdx);
    loadIdx += nParams;
    loadLayer(
        states,
        moduleStateIndices,
        mainModule,
        subModule,
        moduleName,
        paramIdx);
  } else {
    // collect indices for all weights corresponding to the same layer name
    FL_LOG_IF(
        fl::FATAL,
        loadIdx < 0 || static_cast<size_t>(loadIdx) >= states.size())
        << "[LoadModule]: states index is out of range";
    // states is not modified until loadLayer below, so a reference is safe.
    const string& loadModuleName = states[loadIdx].layerName;
    vector<int> moduleStateIndices({loadIdx++});
    while (static_cast<size_t>(loadIdx) < states.size() &&
           states[loadIdx].layerName == loadModuleName) {
      moduleStateIndices.push_back(loadIdx);
      loadIdx++;
    }
    FL_LOG(fl::INFO) << "[LoadModule]: Load module " << loadModuleName
                     << " into " << moduleName;
    loadLayer(
        states,
        moduleStateIndices,
        mainModule,
        subModule,
        moduleName,
        paramIdx);
  }
}