int main()

in recipes/streaming_convnets/tools/StreamingTDSModelConverter.cpp [140:376]


int main(int argc, char** argv) {
  google::InitGoogleLogging(argv[0]);
  google::InstallFailureSignalHandler();
  /* ===================== Parse Options ===================== */
  LOG(INFO) << "Parsing command line flags";
  gflags::ParseCommandLineFlags(&argc, &argv, false);

  /* ===================== Create Network ===================== */
  std::shared_ptr<fl::Module> network;
  std::shared_ptr<SequenceCriterion> criterion;
  std::unordered_map<std::string, std::string> cfg;
  std::string version;
  LOG(INFO) << "[Network] Reading acoustic model from " << FLAGS_am;
  fl::ext::Serializer::load(FLAGS_am, version, cfg, network, criterion);
  network->eval();
  criterion->eval();

  LOG(INFO) << "[Network] " << network->prettyString();
  LOG(INFO) << "[Criterion] " << criterion->prettyString();
  LOG(INFO) << "[Network] Number of params: " << numTotalParams(network);

  auto flags = cfg.find(kGflags);
  if (flags == cfg.end()) {
    LOG(FATAL) << "[Network] Invalid config loaded from " << FLAGS_am;
  }
  LOG(INFO) << "[Network] Updating flags from config file: " << FLAGS_am;
  gflags::ReadFlagsFromString(flags->second, gflags::GetArgv0(), true);

  // override with user-specified flags
  gflags::ParseCommandLineFlags(&argc, &argv, false);
  if (!FLAGS_flagsfile.empty()) {
    gflags::ReadFromFlagsFile(FLAGS_flagsfile, argv[0], true);
  }

  LOG(INFO) << "Gflags after parsing \n" << serializeGflags("; ");

  /* ===================== Create Dictionary ===================== */
  auto dictPath = fl::lib::pathsConcat(FLAGS_tokensdir, FLAGS_tokens);
  if (dictPath.empty() || !fileExists(dictPath)) {
    throw std::runtime_error(
        "Invalid dictionary filepath specified " + dictPath);
  }
  text::Dictionary tokenDict(dictPath);
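  // Add repetition-label tokens ("<1>", "<2>", ...) if the model was trained
  // with replabels, so the exported token set matches the network's outputs.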
  for (int64_t r = 1; r <= FLAGS_replabel; ++r) {
    tokenDict.addEntry("<" + std::to_string(r) + ">");
  }
  if (FLAGS_criterion == kCtcCriterion) {
    tokenDict.addEntry(kBlankToken);
  } else if (FLAGS_criterion != kAsgCriterion) {
    LOG(FATAL) << "This script currently support only CTC/ASG criterion";
  }
  int numTokens = tokenDict.indexSize();
  LOG(INFO) << "Number of classes (network): " << numTokens;

  int nFeat = 0;
  if (FLAGS_mfsc) {
    nFeat = FLAGS_filterbanks;
  } else {
    LOG(FATAL) << "This script currently support only mfsc features";
  }

  auto lines = getFileContent(fl::lib::pathsConcat(FLAGS_archdir, FLAGS_arch));

  auto streamingModule = std::make_shared<streaming::Sequential>();
  auto params = network->params();
  int curFeatSz = nFeat;
  int paramIdx = 0;
  int leftPad = -1, rightPad = -1;
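  // Map each architecture-file entry to an equivalent streaming module,
  // consuming the original network's parameters in order via paramIdx.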
  for (size_t i = 0; i < lines.size(); ++i) {
    auto columns = splitOnWhitespace(lines[i], true);
    if (columns.empty()) {
      continue;
    }
    auto layerType = columns[0];
    if (layerType == "C2") {
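      // Convolution over time: channel counts in the arch file are given as
      // multiples of the input feature dimension, and any preceding PD line
      // supplies the asymmetric left/right padding.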
      if (columns.size() < 8) {
        LOG(FATAL) << "Invalid arch specified for C2";
      }
      auto conv1d = convertConv1d(
          std::stoi(columns[1]) * nFeat,
          std::stoi(columns[2]) * nFeat,
          std::stoi(columns[3]),
          std::stoi(columns[5]),
          {leftPad, rightPad},
          nFeat,
          params[paramIdx],
          params[paramIdx + 1]);
      streamingModule->add(conv1d);
      leftPad = -1;
      rightPad = -1;
      paramIdx += 2;
      curFeatSz = std::stoi(columns[2]) * nFeat;
    } else if (layerType == "PD") {
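      // Padding is not emitted as a separate streaming layer; it is recorded
      // here and folded into the convolution that must follow.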
      if (columns.size() != 4) {
        LOG(FATAL) << "Padding is supported only along time axis";
      }
      if (!startsWith(lines[i + 1], "C2")) {
        LOG(FATAL) << "Padding layer must be followed by conv layer";
      }
      leftPad = std::stoi(columns[2]);
      rightPad = std::stoi(columns[3]);
    } else if (layerType == "R") {
      streamingModule->add(
          std::make_shared<streaming::Relu>(streaming::DataType::FLOAT));
    } else if (layerType == "LN") {
      if (columns[1] != "1" || columns[2] != "2") {
        LOG(FATAL)
            << "Unsupported LayerNorm axis: must be {1, 2} for streaming";
      }
      auto lyrNorm =
          convertLayerNorm(curFeatSz, params[paramIdx], params[paramIdx + 1]);
      streamingModule->add(lyrNorm);
      paramIdx += 2;
    } else if (layerType == "L") {
      int outDim = (columns[2] == "NLABEL") ? numTokens : std::stoi(columns[2]);
      int inDim = std::stoi(columns[1]);
      if (params.size() < paramIdx + 2) {
        LOG(FATAL) << "Error serializing Linear module. Not enough parameters.";
      }
      auto linear =
          convertLinear(inDim, outDim, params[paramIdx], params[paramIdx + 1]);
      streamingModule->add(linear);
      paramIdx += 2;
    } else if (layerType == "TDS") {
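      // A time-depth-separable block consumes 10 parameter tensors from the
      // trained model.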
      auto stds = convertTDS(
          std::stoi(columns[1]),
          std::stoi(columns[2]),
          std::stoi(columns[3]),
          (columns.size() > 6) ? std::stoi(columns[6]) : -1,
          {params.begin() + paramIdx, params.begin() + paramIdx + 10},
          (columns.size() > 5) ? std::stoi(columns[5]) : 0);
      streamingModule->add(stds);
      paramIdx += 10;
    } else if (layerType == "V") {
      std::cerr << "Skipping View module: " << lines[i] << std::endl;
    } else if (layerType == "RO") {
      std::cerr << "Skipping Reorder module: " << lines[i] << std::endl;
    } else if (layerType == "DO") {
      std::cerr << "Skipping Dropout module: " << lines[i] << std::endl;
    } else if (layerType == "SAUG") {
      std::cerr << "Skipping SpecAugment module: " << lines[i] << std::endl;
    } else if (layerType != "#") {
      throw std::logic_error("Unrecognized/unparsable line " + lines[i]);
    }
  }

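  // Export the converted streaming acoustic model with cereal binary
  // serialization; this is the binary loaded at inference time.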
  {
    std::string amFilePath =
        fl::lib::pathsConcat(FLAGS_outdir, "acoustic_model.bin");
    std::ofstream amFile(amFilePath, std::ios::binary);
    LOG(INFO) << "Serializing acoustic model to '" << amFilePath << "'";

    if (!amFile.is_open()) {
      throw std::runtime_error("failed to open file for reading");
    }
    cereal::BinaryOutputArchive ar(amFile);
    ar(streamingModule);
  }

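  // Write the token set, one entry per line, in network output order.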
  {
    std::string tokenFilePath =
        fl::lib::pathsConcat(FLAGS_outdir, "tokens.txt");
    std::ofstream tokenFile(tokenFilePath);
    LOG(INFO) << "Writing tokens file to '" << tokenFilePath << "'";
    for (int i = 0; i < tokenDict.indexSize(); ++i) {
      tokenFile << tokenDict.getEntry(i) << "\n";
    }
    tokenFile.close();
  }

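  // For ASG, also export the learned transition matrix
  // (numTokens x numTokens floats).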
  if (FLAGS_criterion == kAsgCriterion) {
    if (criterion->params().size() == 0 ||
        criterion->param(0).elements() !=
            tokenDict.indexSize() * tokenDict.indexSize()) {
      throw std::runtime_error("Invalid criterion parameters for ASG");
    }
    std::string transitionsFilePath =
        fl::lib::pathsConcat(FLAGS_outdir, "transitions.bin");
    std::ofstream transitionsFile(transitionsFilePath);
    LOG(INFO) << "Writing transitions file to '" << transitionsFilePath << "'";
    std::vector<float> transitionsVec(criterion->param(0).elements());
    criterion->param(0).host(transitionsVec.data());
    fl::save(transitionsFile, transitionsVec);
    transitionsFile.close();
  }

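  // Export the feature frontend: log-mel filterbanks followed by local
  // normalization over the configured left/right context.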
  {
    std::string featFilePath =
        fl::lib::pathsConcat(FLAGS_outdir, "feature_extractor.bin");
    std::ofstream featFile(featFilePath, std::ios::binary);
    LOG(INFO) << "Serializing feature extraction model to '" << featFilePath
              << "'";
    auto featureModule = std::make_shared<streaming::Sequential>();
    featureModule->add(std::make_shared<streaming::LogMelFeature>(nFeat));
    LOG_IF(FATAL, FLAGS_localnrmlleftctx <= 0)
        << "Local Norm should be used for online inference";
    featureModule->add(std::make_shared<streaming::LocalNorm>(
        nFeat, FLAGS_localnrmlleftctx, FLAGS_localnrmlrightctx));

    if (!featFile.is_open()) {
      throw std::runtime_error("failed to open file for reading");
    }
    cereal::BinaryOutputArchive ar(featFile);
    ar(featureModule);
  }

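  // Sanity check: run a random 16x80 input through the original flashlight
  // network and through the converted streaming module, and compare outputs
  // elementwise with a 1e-2 tolerance.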
  LOG(INFO) << "verifying serialization ...";
  af::array inputArr = af::randu(16, 80);
  auto outputArr = network->forward({fl::Variable(inputArr, false)})[0].array();
  std::vector<float> outputVec(outputArr.elements());
  outputArr.host(outputVec.data());

  std::vector<float> inputVec(inputArr.elements());
  af::reorder(inputArr, 2, 1, 0).host(inputVec.data());
  auto inputState = std::make_shared<streaming::ModuleProcessingState>(1);
  std::shared_ptr<streaming::IOBuffer> inputBuffer = inputState->buffer(0);
  inputBuffer->write<float>(inputVec.data(), inputVec.size());

  streamingModule->start(inputState);
  auto outputState = streamingModule->run(inputState);
  streamingModule->finish(inputState);

  std::shared_ptr<streaming::IOBuffer> outputBuffer = outputState->buffer(0);

  LOG_IF(FATAL, outputBuffer->size<float>() != outputVec.size())
      << "[Serialization Error] Incorrect output sizes";
  float* outPtr = outputBuffer->data<float>();
  for (int i = 0; i < outputBuffer->size<float>(); i++) {
    float streamingOut = outPtr[i];
    float w2lOut = outputVec[i];
    LOG_IF(ERROR, fabs(streamingOut - w2lOut) > 1e-2)
        << "[Serialization Error] Mismatched output w2l:" << w2lOut
        << " vs streaming:" << streamingOut;
  }
  LOG(INFO) << "Done !";
  return 0;
}