rocksdb::Status String::LCS()

in src/types/redis_string.cc [508:658]


/// Computes the longest common subsequence (LCS) of the string values stored
/// at two keys, using the classic O(len1 * len2) dynamic-programming table.
///
/// @param ctx        Engine context used for both reads.
/// @param user_key1  First key (namespace prefix is added here); a missing key
///                   is treated as the empty string.
/// @param user_key2  Second key (namespace prefix is added here); a missing key
///                   is treated as the empty string.
/// @param args       `args.type` selects the result form: LEN (length only),
///                   IDX (length plus matching ranges), otherwise the LCS string.
///                   `args.min_match_len` filters IDX ranges (0 = keep all).
/// @param rst        Output. It is reset to the empty alternative matching
///                   `args.type` before any work, so it holds an empty result on
///                   every error path. IDX ranges are appended while
///                   backtracking, i.e. from the end of the strings toward the start.
/// @return OK on success; the read status of either key if it failed with
///         anything other than NotFound; InvalidArgument if a value is too long;
///         Aborted if the DP table would exceed `proto_max_bulk_len`.
rocksdb::Status String::LCS(engine::Context &ctx, const std::string &user_key1, const std::string &user_key2,
                            StringLCSArgs args, StringLCSResult *rst) {
  if (args.type == StringLCSType::LEN) {
    *rst = static_cast<uint32_t>(0);
  } else if (args.type == StringLCSType::IDX) {
    *rst = StringLCSIdxResult{{}, 0};
  } else {
    *rst = std::string{};
  }

  std::string a;
  std::string b;
  std::string ns_key1 = AppendNamespacePrefix(user_key1);
  std::string ns_key2 = AppendNamespacePrefix(user_key2);
  auto s1 = getValue(ctx, ns_key1, &a);
  auto s2 = getValue(ctx, ns_key2, &b);

  if (!s1.ok() && !s1.IsNotFound()) {
    return s1;
  }
  if (!s2.ok() && !s2.IsNotFound()) {
    return s2;
  }
  // A missing key behaves like an empty string.
  if (s1.IsNotFound()) a.clear();
  if (s2.IsNotFound()) b.clear();

  // Detect string truncation or later overflows: lengths must fit in uint32_t
  // with room for the +1 row/column of the DP table.
  if (a.length() >= UINT32_MAX - 1 || b.length() >= UINT32_MAX - 1) {
    return rocksdb::Status::InvalidArgument("String too long for LCS");
  }

  // Compute the LCS using the vanilla dynamic programming technique of
  // building a table of LCS(x, y) substrings.
  auto alen = static_cast<uint32_t>(a.length());
  auto blen = static_cast<uint32_t>(b.length());

  // Allocate the LCS table. Widen to 64 bits *before* multiplying: a product of
  // two uint32_t operands is computed in 32 bits and silently wraps (e.g. to 0
  // for two 65535-byte inputs, or to a small value that under-allocates the
  // table and leads to out-of-bounds writes in the fill loop below).
  const uint64_t row_len = static_cast<uint64_t>(blen) + 1;
  const uint64_t dp_size = (static_cast<uint64_t>(alen) + 1) * row_len;  // >= 1, cannot overflow 64 bits
  const uint64_t bulk_size = dp_size * sizeof(uint32_t);
  // The division check catches 64-bit overflow of bulk_size; dp_size is never 0.
  if (bulk_size > storage_->GetConfig()->proto_max_bulk_len || bulk_size / dp_size != sizeof(uint32_t)) {
    return rocksdb::Status::Aborted("Insufficient memory, transient memory for LCS exceeds proto-max-bulk-len");
  }
  std::vector<uint32_t> dp(static_cast<size_t>(dp_size), 0);
  // Row-major accessor; the index is computed in 64 bits (row_len is uint64_t)
  // so tables larger than 2^32 cells are addressed correctly.
  auto lcs = [&dp, row_len](const uint32_t i, const uint32_t j) -> uint32_t & { return dp[i * row_len + j]; };

  // Start building the LCS table.
  for (uint32_t i = 1; i <= alen; i++) {
    for (uint32_t j = 1; j <= blen; j++) {
      if (a[i - 1] == b[j - 1]) {
        // The len LCS (and the LCS itself) of two
        // sequences with the same final character, is the
        // LCS of the two sequences without the last char
        // plus that last char.
        lcs(i, j) = lcs(i - 1, j - 1) + 1;
      } else {
        // If the last character is different, take the longest
        // between the LCS of the first string and the second
        // minus the last char, and the reverse.
        lcs(i, j) = std::max(lcs(i - 1, j), lcs(i, j - 1));
      }
    }
  }

  uint32_t idx = lcs(alen, blen);

  // Only compute the length of LCS.
  if (auto *len_result = std::get_if<uint32_t>(rst)) {
    *len_result = idx;
    return rocksdb::Status::OK();
  }

  // Resolve the active result alternative once instead of on every
  // backtracking step; at most one of these pointers is non-null.
  auto *idx_result = std::get_if<StringLCSIdxResult>(rst);
  auto *str_result = std::get_if<std::string>(rst);

  // Store the length of the LCS first if needed.
  if (idx_result) {
    idx_result->len = idx;
  }

  // Allocate when we need to compute the actual LCS string.
  if (str_result) {
    str_result->resize(idx);
  }

  // Backtrack from the bottom-right corner of the table, collecting the LCS
  // characters (filled right to left) and contiguous matching ranges.
  uint32_t i = alen;
  uint32_t j = blen;
  uint32_t a_range_start = alen;  // alen signals that values are not set.
  uint32_t a_range_end = 0;
  uint32_t b_range_start = 0;
  uint32_t b_range_end = 0;
  while (i > 0 && j > 0) {
    bool emit_range = false;
    if (a[i - 1] == b[j - 1]) {
      // If there is a match, store the character if needed.
      // And reduce the indexes to look for a new match.
      if (str_result) {
        str_result->at(idx - 1) = a[i - 1];
      }

      // Track the current range.
      if (a_range_start == alen) {
        a_range_start = i - 1;
        a_range_end = i - 1;
        b_range_start = j - 1;
        b_range_end = j - 1;
      }
      // Let's see if we can extend the range backward since
      // it is contiguous.
      else if (a_range_start == i && b_range_start == j) {
        a_range_start--;
        b_range_start--;
      } else {
        emit_range = true;
      }

      // Emit the range if we matched with the first byte of
      // one of the two strings. We'll exit the loop ASAP.
      if (a_range_start == 0 || b_range_start == 0) {
        emit_range = true;
      }
      idx--;
      i--;
      j--;
    } else {
      // Otherwise reduce i and j depending on the largest
      // LCS between, to understand what direction we need to go.
      uint32_t lcs1 = lcs(i - 1, j);
      uint32_t lcs2 = lcs(i, j - 1);
      if (lcs1 > lcs2)
        i--;
      else
        j--;
      // A mismatch ends any open contiguous range.
      if (a_range_start != alen) emit_range = true;
    }

    // Emit the current range if needed.
    if (emit_range) {
      if (idx_result) {
        uint32_t match_len = a_range_end - a_range_start + 1;

        // Always emit the range when the `min_match_len` is not set.
        if (args.min_match_len == 0 || match_len >= args.min_match_len) {
          idx_result->matches.emplace_back(StringLCSRange{a_range_start, a_range_end},
                                           StringLCSRange{b_range_start, b_range_end}, match_len);
        }
      }

      // Restart at the next match.
      a_range_start = alen;
    }
  }

  return rocksdb::Status::OK();
}