vector loadModelStates()

in recipes/utilities/convlm_serializer/Utils.cpp [24:85]


/**
 * Reads a text dump of a ConvLM model and converts every parameter line into
 * a ConvLMParamState.
 *
 * Each non-blank line is expected to be whitespace-separated:
 *   <weightName> <nDims> <shape_0> ... <shape_{nDims-1}> <v_0> ... <v_{N-1}>
 * where N is the product of the shapes and the values are in C (row-major)
 * order. <weightName> must be of the form `{prefix.}layerName.paramName`.
 *
 * Blank lines are skipped. A malformed line (unparsable header, negative or
 * more than 4 dimensions, missing shape entries or values, bad parameter
 * name) is reported through FL_LOG(fl::FATAL), as is an unreadable file.
 *
 * @param weightFile path to the text dump of the model weights
 * @return one state per parameter line, in file order
 */
vector<ConvLMParamState> loadModelStates(const string& weightFile) {
  FL_LOG(fl::INFO) << "[ConvLMSerializer]: Reading pytorch model of the ConvLM";
  FL_LOG_IF(fl::FATAL, !fl::lib::fileExists(weightFile))
      << "Path to weight file " << weightFile << " doesn't exist";

  std::ifstream infile(weightFile);
  // The path may exist but still be unreadable (permissions, directory, ...).
  FL_LOG_IF(fl::FATAL, !infile.is_open())
      << "[LoadModelStates]: Can't open weight file " << weightFile;

  vector<ConvLMParamState> states;
  string line;
  while (getline(infile, line)) {
    // Blank lines carry no parameter; parsing them would leave nDims unset.
    if (line.find_first_not_of(" \t\r\n") == string::npos) {
      continue;
    }

    std::istringstream ss(line);
    string weightName;
    int nDims = 0;
    ss >> weightName >> nDims;
    FL_LOG_IF(fl::FATAL, ss.fail())
        << "[LoadModelStates]: Can't parse parameter name and number of "
        << "dimensions from line starting with: " << line.substr(0, 100);
    // Validate the rank before allocating anything sized by it.
    FL_LOG_IF(fl::FATAL, nDims < 0)
        << "[loadModelStates]: Layer " << weightName
        << " has a negative number of dimensions " << nDims;
    FL_LOG_IF(fl::FATAL, nDims > 4) << "[loadModelStates]: Layer " << weightName
                                    << " has dimensions greater than 4. "
                                    << "This is not supported by ArrayFire";

    vector<int> shapes(nDims);
    int64_t totalElements = 1;
    string shapeStr;
    for (int dim = 0; dim < nDims; ++dim) {
      ss >> shapes[dim];
      FL_LOG_IF(fl::FATAL, ss.fail() || shapes[dim] < 0)
          << "[LoadModelStates]: Invalid or missing size of dimension " << dim
          << " for state " << weightName;
      totalElements *= shapes[dim];
      shapeStr += std::to_string(shapes[dim]) + " ";
    }
    FL_LOG(fl::INFO) << "[LoadModelStates]: Reading state " << weightName
                     << " with dims " << nDims << " and shape " << shapeStr;

    vector<float> data(totalElements);
    // int64_t index: tensors may hold more than INT_MAX elements.
    for (int64_t index = 0; index < totalElements; ++index) {
      ss >> data[index];
    }
    // A failed extraction would otherwise silently leave zeros in the weights.
    FL_LOG_IF(fl::FATAL, ss.fail())
        << "[LoadModelStates]: Expected " << totalElements
        << " parsable values for state " << weightName;

    auto parts = fl::lib::splitOnAnyOf(".", weightName, true);
    FL_LOG_IF(fl::FATAL, parts.size() < 2)
        << "Param name " << weightName
        << " should be in format {prefix.}layerName.paramName";
    // parts = [prefix components..., layerName, paramName]; prefix may be empty.
    const string prefix = fl::lib::join(".", parts.begin(), parts.end() - 2);
    const string layerName = *(parts.end() - 2);
    const string paramName = *(parts.end() - 1);

    // ArrayFire is column-major (Fortran order) while the dump is row-major
    // (C order): load with the axes reversed, then reorder them back.
    af::dim4 dimensions(1, 1, 1, 1);
    vector<int> reordering = {0, 1, 2, 3};
    for (int idx = nDims - 1; idx >= 0; idx--) {
      dimensions[nDims - 1 - idx] = shapes[idx];
      reordering[nDims - 1 - idx] = idx;
    }
    af::array weights = af::array(dimensions, data.data());
    weights = reorder(
        weights, reordering[0], reordering[1], reordering[2], reordering[3]);
    states.push_back({prefix, layerName, paramName, weights});
  }
  // getline also stops on a hard I/O error; don't mistake that for EOF.
  FL_LOG_IF(fl::FATAL, infile.bad())
      << "[LoadModelStates]: I/O error while reading " << weightFile;

  return states;
}