recipes/utilities/convlm_serializer/Utils.h (26 lines of code) (raw):

/* * Copyright (c) Facebook, Inc. and its affiliates. * * This source code is licensed under the MIT-style license found in the * LICENSE file in the root directory of this source tree. */ #pragma once #include <string> #include <vector> #include <flashlight/fl/flashlight.h> struct ConvLMParamState { const std::string moduleName; const std::string layerName; const std::string paramName; af::array weights; }; std::vector<ConvLMParamState> loadModelStates(const std::string& weightFile); void loadLayer( std::vector<ConvLMParamState>& states, std::vector<int>& layerIndices, std::shared_ptr<fl::Module> mainModule, std::shared_ptr<fl::Module> layer, std::string layerName, int paramIdx); void loadConvLM( std::shared_ptr<fl::Module>& network, std::shared_ptr<fl::BinaryModule>& criterion, const std::string& arcFile, const std::string& weightFile, int outputTokensDim, const std::vector<int>& adaptiveTail = std::vector<int>(), int inputSizeAdaptiveSoftmax = 0);