in recipes/utilities/convlm_serializer/Utils.cpp [87:158]
void loadLayer(
    vector<ConvLMParamState>& states,
    vector<int>& layerIndices,
    shared_ptr<fl::Module> mainModule,
    shared_ptr<fl::Module> layer,
    string layerName,
    int paramIdx) {
  // Copy the serialized parameter tensors selected by `layerIndices` out of
  // `states` into the corresponding slots of `mainModule`'s flat parameter
  // list, starting at offset `paramIdx`.
  //
  // Parameters:
  //   states       - all parameter states read from the serialized model.
  //   layerIndices - indices into `states` that belong to this layer.
  //   mainModule   - module whose flat parameter list is written via
  //                  setParams(weights, setIdx + paramIdx).
  //   layer        - the sub-module the parameters belong to; only used to
  //                  decide the required tensor layout (conv vs. other).
  //   layerName    - layer name, used for diagnostics only.
  //   paramIdx     - offset of this layer's first parameter within
  //                  `mainModule`'s flat parameter list.
  //
  // Conv tensors are axis-reordered because the serialized layout differs
  // from flashlight's Conv2D layout; all fatal conditions abort via FL_LOG.
  auto isConvLayer = [&layer]() {
    // A "conv layer" is either a plain Conv2D or a WeightNorm wrapper whose
    // wrapped module prints as Conv2D.
    return dynamic_pointer_cast<fl::Conv2D>(layer) ||
        (dynamic_pointer_cast<fl::WeightNorm>(layer) &&
         layer->prettyString().find("Conv2D") != std::string::npos);
  };
  const bool useGrad = false; // loaded parameters carry no gradients
  const int nParams = layer->params().size();
  // Slot (within this layer) the current tensor is written into; -1 means
  // no "weight" tensor has been seen yet (see the "weight" branch below).
  int setIdx = -1;
  for (auto idx : layerIndices) {
    // Explicit negative-index guard; the original relied on the implicit
    // signed->unsigned conversion to catch negative values.
    FL_LOG_IF(fl::FATAL, idx < 0 || static_cast<size_t>(idx) >= states.size())
        << "[LoadLayer]: states index is out of range";
    const auto& state = states[idx];
    FL_LOG(fl::INFO) << "[LoadLayer]: load layer with param "
                     << state.paramName << " " << state.weights.dims();
    Variable weights;
    if (state.paramName == "weight") {
      // Plain weights arrive in order; advance to the next free slot.
      setIdx++;
      if (dynamic_pointer_cast<fl::Embedding>(layer) ||
          dynamic_pointer_cast<fl::Linear>(
              layer)) { // a hack to load the embedding layer as a linear layer
        weights = Variable(state.weights.T(), useGrad);
      } else {
        weights = Variable(state.weights, useGrad);
      }
    } else if (state.paramName == "weight_v") {
      // Weight-norm direction tensor; stored as this layer's first param.
      setIdx = 0;
      if (isConvLayer()) {
        weights = reorder(Variable(state.weights, useGrad), 0, 3, 1, 2);
      } else {
        weights = Variable(state.weights, useGrad);
      }
    } else if (state.paramName == "weight_g") {
      // Weight-norm magnitude tensor; stored as this layer's second param.
      setIdx = 1;
      if (isConvLayer()) {
        weights = reorder(Variable(state.weights, useGrad), 0, 3, 1, 2);
      } else {
        weights = Variable(state.weights, useGrad);
      }
    } else if (state.paramName == "bias") {
      // Bias is always the layer's last parameter. Use the cached count
      // (the original recomputed layer->params().size() and mixed
      // signed/unsigned arithmetic here).
      setIdx = nParams - 1;
      if (isConvLayer()) {
        weights = reorder(Variable(state.weights, useGrad), 1, 2, 0, 3);
      } else {
        weights = Variable(state.weights, useGrad);
      }
    } else {
      FL_LOG(fl::FATAL) << "[LoadLayer]: Unknown weights param "
                        << state.paramName << " for file layer "
                        << state.layerName
                        << " during loading weights into the model";
    }
    FL_LOG_IF(fl::FATAL, setIdx >= nParams)
        << "[LoadLayer]: Incorrect index of parameter for the file layer "
        << state.layerName << ". There are " << nParams
        << " parameters in the module "
        << " but you are trying to set parameter with index " << setIdx;
    FL_LOG_IF(fl::FATAL, weights.dims() != layer->params()[setIdx].dims())
        << "[CheckSetParams]: The state provides incorrect dimensions for weight tensor."
        << " Loading (layer " << state.paramName
        << ") param dim: " << weights.dims() << " Layer (" << layerName
        << ") param dim: " << layer->params()[setIdx].dims();
    mainModule->setParams(weights, setIdx + paramIdx);
  }
}