in tools/train/source/exec/transformerExecution.cpp [37:338]
int main(int argc, const char* argv[]) {
if (argc < 4) {
MNN_PRINT("Usage: ./transformer.out temp.bin dst.bin config.json [revertInfo.json]\n");
return 0;
}
std::string revertConfigFile = "revert.json";
if (argc >= 5) {
revertConfigFile = argv[4];
}
FUNC_PRINT_ALL(revertConfigFile.c_str(), s);
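    // Read the whole config file (argv[3]) and parse it with rapidjson.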
rapidjson::Document document;
{
std::ifstream fileNames(argv[3]);
std::ostringstream output;
output << fileNames.rdbuf();
auto outputStr = output.str();
document.Parse(outputStr.c_str());
if (document.HasParseError()) {
MNN_ERROR("Invalid json\n");
return 0;
}
FUNC_PRINT(document.HasParseError());
FUNC_PRINT(document.IsArray());
FUNC_PRINT(document.IsObject());
}
auto configObject = document.GetObject();
std::vector<std::string> noUpdateOps;
std::vector<std::string> onlyUpdateOps;
std::vector<std::string> stopBackPropOps;
std::string optimizerType = "SGD";
std::vector<std::string> fixAsConstOps;
std::vector<std::vector<std::string>> weightNameGroups;
std::vector<MNN::Express::VARP> lrNames;
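    // Illustrative sketch of the config.json fields consumed below; the op and
    // variable names are placeholders, not required values:
    // {
    //   "Optimizer": {
    //     "type": "SGD",
    //     "OnlyUpdateOps": ["conv1"],
    //     "NoUpdateOps": ["fc1"],
    //     "StopBackPropOps": ["pool1"],
    //     "FixAsConstOps": ["scale"],
    //     "ParameterOptConfig": [
    //       {"WeightNames": ["conv1_weight"], "LrName": "conv1_lr"}
    //     ]
    //   },
    //   "BatchNorm": {"momentum": 0.99},
    //   "Shape": {"input": [1, 3, 224, 224]},
    //   "Loss": {"op": "loss"},
    //   "Train": {}
    // }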
if (configObject.HasMember("Optimizer")) {
auto optimizer = configObject["Optimizer"].GetObject();
if (optimizer.HasMember("OnlyUpdateOps")) {
auto limitArray = optimizer["OnlyUpdateOps"].GetArray();
for (auto vIter = limitArray.begin(); vIter != limitArray.end(); vIter++) {
onlyUpdateOps.emplace_back(vIter->GetString());
MNN_PRINT("will only update: %s \n", vIter->GetString());
}
}
if (optimizer.HasMember("NoUpdateOps")) {
auto limitArray = optimizer["NoUpdateOps"].GetArray();
for (auto vIter = limitArray.begin(); vIter != limitArray.end(); vIter++) {
noUpdateOps.emplace_back(vIter->GetString());
                if (onlyUpdateOps.empty()) {
                    MNN_PRINT("will not update: %s \n", vIter->GetString());
                }
}
}
if (optimizer.HasMember("StopBackPropOps")) {
auto limitArray = optimizer["StopBackPropOps"].GetArray();
for (auto vIter = limitArray.begin(); vIter != limitArray.end(); vIter++) {
stopBackPropOps.emplace_back(vIter->GetString());
MNN_PRINT("will stop back prop from (also not update this op): %s \n", vIter->GetString());
}
}
if (optimizer.HasMember("type")) {
optimizerType = std::string(optimizer["type"].GetString());
MNN_PRINT("optimizer type: %s\n", optimizerType.c_str());
}
if (optimizer.HasMember("FixAsConstOps")) {
auto limitArray = optimizer["FixAsConstOps"].GetArray();
for (auto vIter = limitArray.begin(); vIter != limitArray.end(); vIter++) {
fixAsConstOps.emplace_back(vIter->GetString());
MNN_PRINT("this op will be fixed as Const, and maybe turn to Trainable later: %s \n", vIter->GetString());
}
}
if (optimizer.HasMember("ParameterOptConfig")) {
auto pConf = optimizer["ParameterOptConfig"].GetArray();
for (auto vIter = pConf.begin(); vIter != pConf.end(); vIter++) {
auto conf = vIter->GetObject();
if (conf.HasMember("WeightNames") && conf.HasMember("LrName")) {
auto wn = conf["WeightNames"].GetArray();
std::vector<std::string> wNames;
for (auto wIter = wn.begin(); wIter != wn.end(); wIter++) {
wNames.push_back(wIter->GetString());
}
weightNameGroups.push_back(wNames);
auto lr = _Input({}, NCHW);
lr->setName(conf["LrName"].GetString());
lrNames.push_back(lr);
}
}
}
}
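    // Batch norm momentum: defaults to 0.99, overridable via the optional
    // "BatchNorm" section of the config.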
auto bnMomentum = new MNN::AttributeT;
bnMomentum->f = 0.99;
if (configObject.HasMember("BatchNorm")) {
auto bnConfig = configObject["BatchNorm"].GetObject();
if (bnConfig.HasMember("momentum")) {
bnMomentum->f = bnConfig["momentum"].GetFloat();
}
}
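    // Load the model from argv[1] and record its usage flag before rebuilding
    // it as an expression graph.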
const char* inputModeFileName = argv[1];
FUNC_PRINT_ALL(inputModeFileName, s);
std::map<std::string, VARP> inputVars;
std::map<std::string, VARP> outputVars;
MNN::Usage netUsage;
{
// Load usage
std::shared_ptr<MNN::Interpreter> net(MNN::Interpreter::createFromFile(argv[1]));
auto buffer = net->getModelBuffer();
auto netStruct = flatbuffers::GetRoot<MNN::Net>(buffer.first);
netUsage = netStruct->usage();
}
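    // Models saved with static inference usage get lazy content evaluation on
    // the global executor, so the graph stays symbolic while it is rewritten.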
if (Usage_INFERENCE_STATIC == netUsage) {
Executor::getGlobalExecutor()->setLazyComputeMode(MNN::Express::Executor::LAZY_CONTENT);
}
{
auto inputsOutputs = Variable::getInputAndOutput(Variable::loadMap(argv[1]));
inputVars = inputsOutputs.first;
outputVars = inputsOutputs.second;
}
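    // Fix the inputs listed in FixAsConstOps as constants; TurnTrainable may
    // later promote them to trainable variables.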
for (auto& varIter : inputVars) {
auto var = varIter.second;
        var->getInfo(); // kept for its shape-inference side effect; the result is unused here
if (!fixAsConstOps.empty()) {
if (std::find(fixAsConstOps.begin(), fixAsConstOps.end(), var->name()) != fixAsConstOps.end()) {
var.fix(VARP::CONSTANT);
}
}
}
Transformer::TrainConfig trainConfig;
trainConfig.noUpdateOps = std::move(noUpdateOps);
trainConfig.onlyUpdateOps = std::move(onlyUpdateOps);
trainConfig.extraParams["BatchNorm"]["momentum"] = bnMomentum;
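    // Rewrite the inference graph so that convolution and batch-norm weights
    // become trainable variables (see MNN::Train::TurnTrainable).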
auto turnTrainable = Train::TurnTrainable(trainConfig);
turnTrainable.onExecute(Variable::mapToSequence(outputVars));
{
// Save Train Revert Info
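        // The revert info records how rewritten variables map back to the
        // original ops (batch-norm statistics, trainable flags, convolution
        // weight/bias) and is written out as a textual flatbuffer.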
std::unique_ptr<MNNTrain::TrainInfoT> trainInfo(new MNNTrain::TrainInfoT);
for (auto& bnIter : turnTrainable.mTrainInfo.bnVariables) {
std::unique_ptr<MNNTrain::KVT> kv(new MNNTrain::KVT);
kv->key = bnIter.first;
kv->value = bnIter.second->name();
trainInfo->batchnormal.emplace_back(std::move(kv));
}
for (auto& iter : turnTrainable.mTrainInfo.trainables) {
std::unique_ptr<MNNTrain::KVT> kv(new MNNTrain::KVT);
kv->key = iter.first;
kv->value = iter.second;
trainInfo->trainables.emplace_back(std::move(kv));
}
for (auto& iter : turnTrainable.mTrainInfo.convolutionVariables) {
std::unique_ptr<MNNTrain::OpInfoT> kv(new MNNTrain::OpInfoT);
kv->op = iter.first;
kv->weight = iter.second.first;
kv->bias = iter.second.second;
trainInfo->convolutions.emplace_back(std::move(kv));
}
flatbuffers::FlatBufferBuilder builder;
builder.Finish(MNNTrain::TrainInfo::Pack(builder, trainInfo.get()));
std::ofstream _t(revertConfigFile.c_str());
auto s = flatbuffers::FlatBufferToString((const uint8_t*)builder.GetBufferPointer(), MNNTrain::TrainInfoTypeTable());
_t << s;
}
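    // Resize inputs according to the optional "Shape" section of the config.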
if (configObject.HasMember("Shape")) {
        auto shapeObject = configObject["Shape"].GetObject();
        for (auto shapeIter = shapeObject.begin(); shapeIter != shapeObject.end(); shapeIter++) {
auto dimArray = shapeIter->value.GetArray();
std::vector<int> dims;
for (auto dimIter = dimArray.begin(); dimIter != dimArray.end(); dimIter++) {
dims.emplace_back(dimIter->GetInt());
}
FUNC_PRINT_ALL(shapeIter->name.GetString(), s);
std::string key = shapeIter->name.GetString();
for (auto& varIter : inputVars) {
auto var = varIter.second;
if (var->name() == key) {
var->resize(dims);
break;
}
}
}
}
auto exprs = Variable::getExecuteOrder(Variable::mapToSequence(outputVars));
    // Collect trainable parameters (leaf variables marked TRAINABLE)
std::set<VARP> parameters;
for (auto v : exprs) {
if (v->get() == nullptr && VARP::TRAINABLE == v->inputType()) {
auto va = Variable::create(v, 0);
parameters.insert(va);
}
}
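    // Fix the collected parameters as CONSTANT before the gradient graph is built.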
for (auto p : parameters) {
p.fix(VARP::CONSTANT);
}
VARP loss;
bool train = configObject.HasMember("Train");
if (!train) {
MNN_PRINT("Don't has member Train, generate grad model\n");
}
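    // If the config names no "Loss" op, synthesize one from the first output:
    // cross-entropy against a "<name>_Compare" input (or squared error when
    // USE_ELU is defined), plus L2 regularization over all parameters.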
bool hasLoss = configObject.HasMember("Loss");
if (!hasLoss) {
        if (outputVars.empty()) {
            MNN_ERROR("Model has no output, can't build the default loss\n");
            return 0;
        }
        auto output = outputVars.begin()->second;
auto outputShape = output->getInfo();
if (outputShape->order == NC4HW4) {
auto outputName = output->name();
output->setName(outputName + "Origin");
output = _Convert(output, NHWC);
outputShape = output->getInfo();
output->setName(outputName);
}
auto outputReal = _Input(outputShape->dim, outputShape->order);
outputReal->setName(output->name() + "_Compare");
#ifdef USE_ELU
auto sub = _Subtract(output, outputReal);
sub->setName(output->name() + "_Sub");
loss = (_ReduceSum(_Multiply(sub, sub), {}));
#else
auto mul = _Multiply(_Log(output), outputReal);
mul->setName(output->name() + "_Mul");
loss = _Negative(_ReduceSum(mul, {}));
#endif
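        // L2 regularization over all parameters, weighted by 5e-4.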
auto l2 = _Const(0.0f);
for (auto var : parameters) {
l2 = l2 + (var * var).sum({});
}
loss = loss + _Multiply(l2, _Const(0.0005f));
loss->setName("Loss");
exprs = Variable::getExecuteOrder({loss});
} else {
std::string lossName = configObject["Loss"].GetObject()["op"].GetString();
for (auto expr : exprs) {
if (expr->name() == lossName) {
loss = Variable::create(expr);
break;
}
}
        outputVars.erase(lossName);
if (nullptr == loss.get()) {
MNN_ERROR("Can't find loss op\n");
return 0;
}
}
    MNN_ASSERT(nullptr != loss);
    loss->getInfo(); // compute the loss shape; the info itself is not needed
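    // Differentiate the loss with respect to the collected parameters, stopping
    // back-propagation at the ops listed in StopBackPropOps.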
auto gradMap = OpGrad::grad(loss, parameters, stopBackPropOps);
if (gradMap.empty()) {
MNN_ERROR("Grad error, don't has grad\n");
return 0;
}
for (auto iter : gradMap) {
if (!iter.first->name().empty()) {
iter.second->setName(iter.first->name() + "::grad");
}
}
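    // Without a "Train" member, emit a grad-only model: parameters become
    // inputs and their gradients become extra outputs next to the loss.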
if (!train) {
std::vector<MNN::Express::VARP> gradVars = {loss};
for (auto iter : gradMap) {
iter.first.fix(VARP::INPUT);
gradVars.emplace_back(iter.second);
}
ParameterOptimizer::makeLoopModel(argv[2], gradVars, std::make_pair(std::vector<MNN::Express::VARP>{}, std::vector<MNN::Express::VARP>{}));
return 0;
}
// Make Update
std::shared_ptr<MNN::Train::ParameterOptimizer> optimizer;
    if (optimizerType == "SGD") {
        optimizer.reset(MNN::Train::ParameterOptimizer::createSGD(nullptr, 0.01f, 0.90f, 0.00f, MNN::Train::ParameterOptimizer::L1));
    } else if (optimizerType == "ADAM") {
        optimizer.reset(MNN::Train::ParameterOptimizer::createADAM(nullptr, 0.01f, 0.90f, 0.999f, 0.00f, 0.00005f, MNN::Train::ParameterOptimizer::L1));
    } else {
        MNN_ERROR("Unsupported optimizer type: %s\n", optimizerType.c_str());
        return 0;
    }
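    // Every update uses the scalar "LearningRate" input unless the parameter
    // belongs to a ParameterOptConfig group, which supplies its own lr input.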
auto learningRate = _Input({}, NCHW);
learningRate->setName("LearningRate");
std::vector<ParameterOptimizer::ParameterOptGrad> gradVars;
for (auto iter : gradMap) {
ParameterOptimizer::ParameterOptGrad gradVar;
gradVar.parameter = iter.first;
gradVar.parameterGrad = iter.second;
gradVar.learningRate = learningRate;
if (!lrNames.empty()) {
// Find lr Index
auto pName = iter.first->name();
            for (size_t ii = 0; ii < weightNameGroups.size(); ii++) {
if (std::find(weightNameGroups[ii].begin(), weightNameGroups[ii].end(), pName) != weightNameGroups[ii].end()) {
gradVar.learningRate = lrNames[ii];
break;
}
}
}
gradVars.emplace_back(gradVar);
}
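    // Generate the parameter-update loop graph and write the final training
    // model to dst.bin (argv[2]).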
auto loopPair = optimizer->onMakeParameterUpdateGraphByGrad(gradVars);
std::vector<VARP> resultOutputs = {loss};
ParameterOptimizer::makeLoopModel(argv[2], resultOutputs, loopPair);
return 0;
}