recipes/utilities/convlm_serializer/Utils.cpp (277 lines of code) (raw):

/* * Copyright (c) Facebook, Inc. and its affiliates. * * This source code is licensed under the MIT-style license found in the * LICENSE file in the root directory of this source tree. */ #include "recipes/utilities/convlm_serializer/Utils.h" #include <fstream> #include <flashlight/fl/contrib/contrib.h> #include <flashlight/lib/common/String.h> #include <flashlight/lib/common/System.h> #include <flashlight/pkg/runtime/common/SequentialBuilder.h> using fl::Variable; using std::dynamic_pointer_cast; using std::make_shared; using std::shared_ptr; using std::string; using std::vector; vector<ConvLMParamState> loadModelStates(const string& weightFile) { FL_LOG(fl::INFO) << "[ConvLMSerializer]: Reading pytorch model of the ConvLM"; FL_LOG_IF(fl::FATAL, !fl::lib::fileExists(weightFile)) << "Path to weight file " << weightFile << " doesn't exist"; vector<ConvLMParamState> states; std::ifstream infile(weightFile); string line; while (getline(infile, line)) { std::stringstream ss; string weightName; int nDims; int64_t totalElements = 1; ss << line; ss >> weightName >> nDims; vector<int> shapes(nDims); string shape_str = ""; for (int dim = 0; dim < nDims; dim++) { ss >> shapes[dim]; totalElements *= shapes[dim]; shape_str += std::to_string(shapes[dim]) + " "; } FL_LOG(fl::INFO) << "[LoadModelStates]: Reading state " << weightName << " with dims " << nDims << " and shape " << shape_str; vector<float> data(totalElements); for (int index = 0; index < totalElements; index++) { ss >> data[index]; } auto parts = fl::lib::splitOnAnyOf(".", weightName, true); FL_LOG_IF(fl::FATAL, parts.size() < 2) << "Param name " << weightName << " should be in format {prefix.}layerName.paramName"; vector<string> names = { fl::lib::join(".", parts.begin(), parts.end() - 2), *(parts.end() - 2), *(parts.end() - 1)}; FL_LOG_IF(fl::FATAL, names.size() != 3) << "[LoadModelStates]: Error during parsing parameter name"; af::dim4 dimensions(1, 1, 1, 1); // af has fortran-ordering (column-way) // revert axis before loading c-ordered matrices (row-way) vector<int> reordering = {0, 1, 2, 3}; FL_LOG_IF(fl::FATAL, nDims > 4) << "[loadModelStates]: Layer " << weightName << " has dimensions greater than 4. " << "This is not supported by ArrayFire"; for (int idx = nDims - 1; idx >= 0; idx--) { dimensions[nDims - 1 - idx] = shapes[idx]; reordering[nDims - 1 - idx] = idx; } af::array weights = af::array(dimensions, data.data()); weights = reorder( weights, reordering[0], reordering[1], reordering[2], reordering[3]); states.push_back({names[0], names[1], names[2], weights}); } infile.close(); return states; } void loadLayer( vector<ConvLMParamState>& states, vector<int>& layerIndices, shared_ptr<fl::Module> mainModule, shared_ptr<fl::Module> layer, string layerName, int paramIdx) { auto isConvLayer = [&layer]() { return dynamic_pointer_cast<fl::Conv2D>(layer) || (dynamic_pointer_cast<fl::WeightNorm>(layer) && layer->prettyString().find("Conv2D") != std::string::npos); }; bool useGrad = false; int nParams = layer->params().size(); int setIdx = -1; for (auto idx : layerIndices) { FL_LOG_IF(fl::FATAL, idx >= states.size()) << "[LoadLayer]: states index is out of range"; FL_LOG(fl::INFO) << "[LoadLayer]: load layer with param " << states[idx].paramName << " " << states[idx].weights.dims(); Variable weights; if (states[idx].paramName == "weight") { setIdx++; if (dynamic_pointer_cast<fl::Embedding>(layer) || dynamic_pointer_cast<fl::Linear>( layer)) { // a hack to load the embedding layer as a linear layer weights = Variable(states[idx].weights.T(), useGrad); } else { weights = Variable(states[idx].weights, useGrad); } } else if (states[idx].paramName == "weight_v") { setIdx = 0; if (isConvLayer()) { weights = reorder(Variable(states[idx].weights, useGrad), 0, 3, 1, 2); } else { weights = Variable(states[idx].weights, useGrad); } } else if (states[idx].paramName == "weight_g") { setIdx = 1; if (isConvLayer()) { weights = reorder(Variable(states[idx].weights, useGrad), 0, 3, 1, 2); } else { weights = Variable(states[idx].weights, useGrad); } } else if (states[idx].paramName == "bias") { setIdx = layer->params().size() - 1; if (isConvLayer()) { weights = reorder(Variable(states[idx].weights, useGrad), 1, 2, 0, 3); } else { weights = Variable(states[idx].weights, useGrad); } } else { FL_LOG(fl::FATAL) << "[LoadLayer]: Unknown weights param " << states[idx].paramName << " for file layer " << states[idx].layerName << " during loading weights into the model"; } FL_LOG_IF(fl::FATAL, setIdx >= nParams) << "[LoadLayer]: Incorrect index of parameter for the file layer " << states[idx].layerName << ". There are " << nParams << " parameters in the module " << " but you are trying to set parameter with index " << setIdx; FL_LOG_IF(fl::FATAL, weights.dims() != layer->params()[setIdx].dims()) << "[CheckSetParams]: The state provides incorrect dimensions for weight tensor." << " Loading (layer " << states[idx].paramName << ") param dim: " << weights.dims() << " Layer (" << layerName << ") param dim: " << layer->params()[setIdx].dims(); mainModule->setParams(weights, setIdx + paramIdx); } } void loadModule( vector<ConvLMParamState>& states, shared_ptr<fl::Module> mainModule, shared_ptr<fl::Module> subModule, int& loadIdx, int paramIdx) { int nParams = subModule->params().size(); string moduleName = subModule->prettyString(); // if no parameters for layer then skip loading weights for it if (nParams == 0) { FL_LOG(fl::INFO) << "[LoadModule]: Skip loading params for " << moduleName; return; } if (dynamic_pointer_cast<fl::Sequential>(subModule) != nullptr) { // in the sequential block FL_LOG(fl::INFO) << "[LoadModule]: Load sequential block " << moduleName; auto moduleCast = dynamic_pointer_cast<fl::Sequential>(subModule); auto submodules = moduleCast->modules(); for (auto smd : submodules) { loadModule(states, mainModule, smd, loadIdx, paramIdx); paramIdx += smd->params().size(); } } else if (dynamic_pointer_cast<fl::Residual>(subModule) != nullptr) { // in the res block FL_LOG(fl::INFO) << "[LoadModule]: Load residual block " << moduleName; auto moduleCast = dynamic_pointer_cast<fl::Residual>(subModule); auto submodules = moduleCast->modules(); auto projectionIndices = moduleCast->getProjectionsIndices(); std::vector<int64_t> cumParamSize(submodules.size()); for (int ind = 0; ind < submodules.size(); ind++) { if (ind > 0) { cumParamSize[ind] = cumParamSize[ind - 1] + submodules[ind - 1]->params().size(); } // load modules before loading projection matrices if (projectionIndices.find(ind) == projectionIndices.end()) { loadModule( states, mainModule, submodules[ind], loadIdx, paramIdx + cumParamSize[ind]); } } for (int ind = 0; ind < submodules.size(); ind++) { if (projectionIndices.find(ind) != projectionIndices.end()) { loadModule( states, mainModule, submodules[ind], loadIdx, paramIdx + cumParamSize[ind]); } } } else if (dynamic_pointer_cast<fl::AdaptiveSoftMaxLoss>(subModule)) { FL_LOG(fl::INFO) << "[LoadModule]: Load adaptive softmax loss " << moduleName; vector<int> moduleStateIndices(subModule->params().size()); std::iota(moduleStateIndices.begin(), moduleStateIndices.end(), loadIdx); loadIdx += subModule->params().size(); loadLayer( states, moduleStateIndices, mainModule, subModule, moduleName, paramIdx); } else { // collect indices for all weights corresponding to the same layer name FL_LOG_IF(fl::FATAL, loadIdx >= states.size()) << "[LoadModule]: states index is out of range"; string loadModuleName = states[loadIdx].layerName; vector<int> moduleStateIndices({loadIdx++}); while ((loadIdx < states.size()) && (states[loadIdx].layerName == loadModuleName)) { moduleStateIndices.push_back(loadIdx); loadIdx++; } FL_LOG(fl::INFO) << "[LoadModule]: Load module " << loadModuleName << " into " << moduleName; loadLayer( states, moduleStateIndices, mainModule, subModule, moduleName, paramIdx); } } void setParams( shared_ptr<fl::Module> network, shared_ptr<fl::BinaryModule> criterion, vector<ConvLMParamState>& states) { FL_LOG(fl::INFO) << "[SetParams]: Load weights into the model"; int loadIdx = 0, paramIdx = 0; auto networkCast = dynamic_pointer_cast<fl::Sequential>(network); for (auto module : networkCast->modules()) { loadModule(states, networkCast, module, loadIdx, paramIdx); paramIdx += module->params().size(); } if ((criterion != nullptr) && (criterion->params().size() > 0)) { loadModule(states, criterion, criterion, loadIdx, 0); } FL_LOG_IF(fl::FATAL, loadIdx < states.size()) << "[SetParams]: Some weights are remain in the file during loading the model"; FL_LOG(fl::INFO) << "[SetParams]: Finish loading weight from the file into the model"; } void loadConvLM( shared_ptr<fl::Module>& network, shared_ptr<fl::BinaryModule>& criterion, const string& archFile, const string& weightFile, int outputTokensDim, const vector<int>& adaptiveTail /* = std::vector<int>() */, int inputSizeAdaptiveSoftmax /* = 0 */) { FL_LOG_IF(fl::FATAL, !fl::lib::fileExists(archFile)) << "Path to arch file " << archFile << " doesn't exist"; FL_LOG_IF(fl::FATAL, !fl::lib::fileExists(weightFile)) << "Path to weight file " << weightFile << " doesn't exist"; // create network and criterion network = fl::pkg::runtime::buildSequentialModule(archFile, 1, outputTokensDim); network->eval(); if (adaptiveTail.size() > 0) { auto activation = make_shared<fl::AdaptiveSoftMax>( inputSizeAdaptiveSoftmax, adaptiveTail); criterion = make_shared<fl::AdaptiveSoftMaxLoss>(activation); criterion->eval(); } else { criterion = nullptr; } // Loading weights from the binary file FL_LOG(fl::INFO) << "[LoadConvLM]: Load states"; auto modelStates = loadModelStates(weightFile); FL_LOG_IF( fl::FATAL, modelStates.size() != network->params().size() + (criterion ? criterion->params().size() : 0)) << "mismatch between the number of parameters in the arch file and the weight file " << modelStates.size() << " model states vs " << network->params().size() << " nn params + " << (criterion ? criterion->params().size() : 0) << " criterion params"; // Load weight states into network and criterion FL_LOG(fl::INFO) << "[LoadConvLM]: set params"; setParams(network, criterion, modelStates); }