in recipes/utilities/convlm_serializer/Utils.cpp [24:85]
/**
 * Load the parameter states of a ConvLM model from a text dump.
 *
 * Each non-blank line of `weightFile` describes one parameter as
 * whitespace-separated tokens:
 *   <name> <nDims> <dim_0> ... <dim_{nDims-1}> <value_0> ... <value_{N-1}>
 * where N is the product of the dims (1 when nDims == 0) and the values are
 * listed in C (row-major) order. `name` must look like
 * `{prefix.}layerName.paramName`; it is split into (prefix, layerName,
 * paramName), with an empty prefix when only two components are present.
 *
 * Malformed input (missing/unopenable file, unparsable header, a dim count
 * outside [0, 4], negative dims, too few values, bad name) is reported via
 * FL_LOG(fl::FATAL). Like the rest of this function, the code after each
 * FATAL check presumes FATAL stops execution.
 *
 * @param weightFile path to the text dump of the pytorch model.
 * @return one ConvLMParamState per parameter line, in file order, with the
 *         weights reordered so ArrayFire axis k matches C axis k.
 */
vector<ConvLMParamState> loadModelStates(const string& weightFile) {
  FL_LOG(fl::INFO) << "[ConvLMSerializer]: Reading pytorch model of the ConvLM";
  FL_LOG_IF(fl::FATAL, !fl::lib::fileExists(weightFile))
      << "Path to weight file " << weightFile << " doesn't exist";
  std::ifstream infile(weightFile);
  FL_LOG_IF(fl::FATAL, !infile)
      << "[LoadModelStates]: Could not open weight file " << weightFile;

  vector<ConvLMParamState> states;
  string line;
  while (getline(infile, line)) {
    std::istringstream ss(line);
    string weightName;
    if (!(ss >> weightName)) {
      // Blank or whitespace-only line: nothing to load.
      continue;
    }

    // Validate the dim count before using it to size anything.
    int nDims = 0;
    ss >> nDims;
    FL_LOG_IF(fl::FATAL, ss.fail())
        << "[LoadModelStates]: Could not read the number of dims of "
        << weightName;
    FL_LOG_IF(fl::FATAL, nDims < 0)
        << "[LoadModelStates]: Layer " << weightName
        << " has a negative number of dimensions " << nDims;
    FL_LOG_IF(fl::FATAL, nDims > 4) << "[loadModelStates]: Layer " << weightName
                                    << " has dimensions greater than 4. "
                                    << "This is not supported by ArrayFire";

    vector<int> shapes(nDims);
    int64_t totalElements = 1;
    string shape_str;
    for (int dim = 0; dim < nDims; ++dim) {
      ss >> shapes[dim];
      FL_LOG_IF(fl::FATAL, ss.fail() || shapes[dim] < 0)
          << "[LoadModelStates]: Invalid size for dimension " << dim
          << " of " << weightName;
      totalElements *= shapes[dim];
      shape_str += std::to_string(shapes[dim]) + " ";
    }
    FL_LOG(fl::INFO) << "[LoadModelStates]: Reading state " << weightName
                     << " with dims " << nDims << " and shape " << shape_str;

    // int64_t index: the element count can exceed the range of int.
    vector<float> data(totalElements);
    for (int64_t index = 0; index < totalElements; ++index) {
      ss >> data[index];
    }
    // A failed extraction would otherwise silently leave trailing zeros.
    FL_LOG_IF(fl::FATAL, ss.fail())
        << "[LoadModelStates]: Expected " << totalElements
        << " values for state " << weightName;

    auto parts = fl::lib::splitOnAnyOf(".", weightName, true);
    FL_LOG_IF(fl::FATAL, parts.size() < 2)
        << "Param name " << weightName
        << " should be in format {prefix.}layerName.paramName";
    const string prefix = fl::lib::join(".", parts.begin(), parts.end() - 2);
    const string layerName = *(parts.end() - 2);
    const string paramName = *(parts.end() - 1);

    // af has fortran-ordering (column-way)
    // revert axis before loading c-ordered matrices (row-way):
    // load with reversed dims, then reorder so af axis k == C axis k.
    af::dim4 dimensions(1, 1, 1, 1);
    vector<int> reordering = {0, 1, 2, 3};
    for (int idx = nDims - 1; idx >= 0; --idx) {
      dimensions[nDims - 1 - idx] = shapes[idx];
      reordering[nDims - 1 - idx] = idx;
    }
    af::array weights = af::array(dimensions, data.data());
    weights = af::reorder(
        weights, reordering[0], reordering[1], reordering[2], reordering[3]);
    states.push_back({prefix, layerName, paramName, weights});
  }
  return states;
}