tflite::support::StatusOr ConvertSentencepieceModelToFlatBuffer()

in tensorflow_lite_support/custom_ops/kernel/sentencepiece/model_converter.cc [50:127]


tflite::support::StatusOr<std::string> ConvertSentencepieceModelToFlatBuffer(
    const std::string& model_config_str, int encoding_offset) {
  ::sentencepiece::ModelProto model_config;
  if (!model_config.ParseFromString(model_config_str)) {
    return absl::InvalidArgumentError(
        "Invalid configuration, can't parse SentencePiece model config " +
        model_config.InitializationErrorString());
  }
  // Convert sentencepieces.
  std::vector<std::string> pieces;
  pieces.reserve(model_config.pieces_size());
  std::vector<float> scores;
  scores.reserve(model_config.pieces_size());
  std::vector<int> ids;
  ids.reserve(model_config.pieces_size());
  float min_score = 0.0;
  int index = 0;
  for (const auto& piece : model_config.pieces()) {
    switch (piece.type()) {
      case ::sentencepiece::ModelProto::SentencePiece::NORMAL:
      case ::sentencepiece::ModelProto::SentencePiece::USER_DEFINED:
        pieces.push_back(piece.piece());
        ids.push_back(index);
        if (piece.score() < min_score) {
          min_score = piece.score();
        }
        break;
      case ::sentencepiece::ModelProto::SentencePiece::UNKNOWN:
      case ::sentencepiece::ModelProto::SentencePiece::CONTROL:
        // Ignore unknown and control codes.
        break;
      default:
        return absl::InvalidArgumentError("Invalid SentencePiece piece type " +
                                          piece.piece());
    }
    scores.push_back(piece.score());
    ++index;
  }
  flatbuffers::FlatBufferBuilder builder(1024);
  const auto pieces_trie_vector = builder.CreateVector(BuildTrie(pieces, ids));
  const auto pieces_score_vector = builder.CreateVector(scores);
  TrieBuilder pieces_trie_builder(builder);
  pieces_trie_builder.add_nodes(pieces_trie_vector);
  const auto pieces_trie_fbs = pieces_trie_builder.Finish();

  // Converting normalization.
  const auto normalization =
      DecodePrecompiledCharsmap(model_config.normalizer_spec());
  const auto normalization_trie = std::get<0>(normalization);
  const auto normalization_strings = std::get<1>(normalization);
  const auto normalization_trie_vector =
      builder.CreateVector(normalization_trie);
  TrieBuilder normalization_trie_builder(builder);
  normalization_trie_builder.add_nodes(normalization_trie_vector);
  const auto normalization_trie_fbs = normalization_trie_builder.Finish();
  const auto normalization_strings_fbs =
      builder.CreateVector(normalization_strings);

  EncoderConfigBuilder ecb(builder);
  ecb.add_version(EncoderVersion::EncoderVersion_SENTENCE_PIECE);
  ecb.add_start_code(model_config.trainer_spec().bos_id());
  ecb.add_end_code(model_config.trainer_spec().eos_id());
  ecb.add_unknown_code(model_config.trainer_spec().unk_id());
  ecb.add_unknown_penalty(min_score - kUnkPenalty);
  ecb.add_encoding_offset(encoding_offset);
  ecb.add_pieces(pieces_trie_fbs);
  ecb.add_pieces_scores(pieces_score_vector);
  ecb.add_remove_extra_whitespaces(
      model_config.normalizer_spec().remove_extra_whitespaces());
  ecb.add_add_dummy_prefix(model_config.normalizer_spec().add_dummy_prefix());
  ecb.add_escape_whitespaces(
      model_config.normalizer_spec().escape_whitespaces());
  ecb.add_normalized_prefixes(normalization_trie_fbs);
  ecb.add_normalized_replacements(normalization_strings_fbs);
  FinishEncoderConfigBuffer(builder, ecb.Finish());
  return std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
                     builder.GetSize());
}