void loadLayer()

in recipes/utilities/convlm_serializer/Utils.cpp [87:158]
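
loadLayer copies the serialized parameters of one layer into the matching
module of a flashlight model. layerIndices selects the entries of states that
belong to this layer; each entry is dispatched on its paramName ("weight",
"weight_v"/"weight_g" for weight normalization, or "bias"), transposed or
reordered where the serialized layout differs from flashlight's, and written
into mainModule at flat offset paramIdx. As a rough sketch, these are the
fields the function relies on; the actual ConvLMParamState definition lives
elsewhere in convlm_serializer and may differ in detail:

struct ConvLMParamState {
  std::string layerName; // layer name as it appears in the serialized file
  std::string paramName; // "weight", "weight_v", "weight_g", or "bias"
  af::array weights;     // the parameter tensor itself (ArrayFire array)
};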


void loadLayer(
    vector<ConvLMParamState>& states,
    vector<int>& layerIndices,
    shared_ptr<fl::Module> mainModule,
    shared_ptr<fl::Module> layer,
    string layerName,
    int paramIdx) {
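  // `states` holds every serialized parameter tensor; `layerIndices` selects
  // the entries belonging to `layer`. Converted tensors are written into
  // `mainModule` starting at flat parameter offset `paramIdx`.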
  // Detect convolutions, including a Conv2D wrapped in WeightNorm: their
  // weight tensors get their axes reordered below.
  auto isConvLayer = [&layer]() {
    return dynamic_pointer_cast<fl::Conv2D>(layer) ||
        (dynamic_pointer_cast<fl::WeightNorm>(layer) &&
         layer->prettyString().find("Conv2D") != std::string::npos);
  };

  // Loaded parameters are frozen: no gradients are tracked for them.
  bool useGrad = false;
  int nParams = layer->params().size();
  // Index of the parameter slot inside `layer` that the current entry fills.
  int setIdx = -1;
  for (auto idx : layerIndices) {
    FL_LOG_IF(fl::FATAL, static_cast<size_t>(idx) >= states.size())
        << "[LoadLayer]: states index is out of range";
    FL_LOG(fl::INFO) << "[LoadLayer]: load layer with param "
                     << states[idx].paramName << " "
                     << states[idx].weights.dims();
    Variable weights;
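    // Map the serialized name onto the module's parameter slot and fix up the
    // tensor layout where the serialized and flashlight conventions differ.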
    if (states[idx].paramName == "weight") {
      // A plain weight fills the next parameter slot.
      setIdx++;
      if (dynamic_pointer_cast<fl::Embedding>(layer) ||
          dynamic_pointer_cast<fl::Linear>(layer)) {
        // A hack to load the embedding layer as a linear layer: the serialized
        // weight is transposed relative to the layout these modules expect.
        weights = Variable(states[idx].weights.T(), useGrad);
      } else {
        weights = Variable(states[idx].weights, useGrad);
      }
    } else if (states[idx].paramName == "weight_v") {
      // Weight-norm direction tensor v (w = g * v / ||v||); first param slot.
      setIdx = 0;
      if (isConvLayer()) {
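        // Reorder the conv weight axes from the serialized layout (presumably
        // a PyTorch export) to the one flashlight's Conv2D expects.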
        weights = reorder(Variable(states[idx].weights, useGrad), 0, 3, 1, 2);
      } else {
        weights = Variable(states[idx].weights, useGrad);
      }
    } else if (states[idx].paramName == "weight_g") {
      // Weight-norm magnitude g; second param slot.
      setIdx = 1;
      if (isConvLayer()) {
        weights = reorder(Variable(states[idx].weights, useGrad), 0, 3, 1, 2);
      } else {
        weights = Variable(states[idx].weights, useGrad);
      }
    } else if (states[idx].paramName == "bias") {
      // The bias is always the module's last parameter.
      setIdx = nParams - 1;
      if (isConvLayer()) {
        weights = reorder(Variable(states[idx].weights, useGrad), 1, 2, 0, 3);
      } else {
        weights = Variable(states[idx].weights, useGrad);
      }
    } else {
      FL_LOG(fl::FATAL) << "[LoadLayer]: Unknown weights param "
                        << states[idx].paramName << " for file layer "
                        << states[idx].layerName
                        << " during loading weights into the model";
    }
    FL_LOG_IF(fl::FATAL, setIdx >= nParams)
        << "[LoadLayer]: Incorrect index of parameter for the file layer "
        << states[idx].layerName << ". There are " << nParams
        << " parameters in the module"
        << " but you are trying to set the parameter with index " << setIdx;
    FL_LOG_IF(fl::FATAL, weights.dims() != layer->params()[setIdx].dims())
        << "[CheckSetParams]: The state provides incorrect dimensions for the weight tensor."
        << " Loading (param " << states[idx].paramName
        << ") dim: " << weights.dims() << " Layer (" << layerName
        << ") param dim: " << layer->params()[setIdx].dims();
    mainModule->setParams(weights, setIdx + paramIdx);
  }
}
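
A minimal sketch of how loadLayer might be driven, assuming states arrive
grouped by serialized layer name in the same order as the model's modules;
loadAllLayers and this grouping scheme are hypothetical (the real caller in
convlm_serializer matches layers differently), and the same using-declarations
as Utils.cpp are assumed:

void loadAllLayers(
    vector<ConvLMParamState>& states,
    shared_ptr<fl::Sequential> model) {
  int paramIdx = 0;
  int stateIdx = 0;
  int nStates = static_cast<int>(states.size());
  for (auto& layer : model->modules()) {
    // Collect the run of consecutive states sharing one serialized layer name.
    vector<int> layerIndices;
    string name = stateIdx < nStates ? states[stateIdx].layerName : "";
    while (stateIdx < nStates && states[stateIdx].layerName == name) {
      layerIndices.push_back(stateIdx++);
    }
    loadLayer(states, layerIndices, model, layer, name, paramIdx);
    paramIdx += layer->params().size();
  }
}

Advancing paramIdx by each module's own parameter count is what makes
setIdx + paramIdx inside loadLayer land on the right slot of the main module's
flattened parameter list.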