Map<Var, Expr> TupleRewriterNode::GenerateVariableRewrites()

in src/relax/ir/dataflow_expr_rewriter.cc [338:522]


Map<Var, Expr> TupleRewriterNode::GenerateVariableRewrites(const Array<Binding>& bindings) const {
  // Walks `bindings` in order. After each binding is appended, it searches
  // for an assignment of one binding per pattern in `patterns`, where the
  // newest binding is always one of the chosen bindings. For the first
  // assignment accepted by TryMatchByBindingIndex, every chosen variable is
  // mapped to its replacement expression and marked as used, so it cannot
  // take part in a later match.
  //
  // Returns the mapping from rewritten variables to their replacements.

  // Result: replacement expression for each variable consumed by a match.
  Map<Var, Expr> rewrites;

  // Bound values of bindings that have not (yet) been consumed by a match.
  // Passed to ExtractMatchedExpr when evaluating each new binding.
  Map<Var, Expr> binding_lookup;

  // Per-binding match state, in binding order: info_vec[k] describes bindings[k].
  std::vector<VarInfo> info_vec;

  // Position of each bound variable within `bindings` (and so within info_vec).
  std::unordered_map<Var, size_t> binding_index_lookup;

  // Search backwards, starting from the second-to-last binding, for the
  // latest binding that satisfies all three conditions:
  //   (a) it matches pattern `pattern_index`,
  //   (b) it was not consumed by an earlier rewrite, and
  //   (c) it is not already assigned to any *later* pattern slot, i.e. to
  //       indices[pattern_index + 1 ..].
  //
  // Only later slots are checked. Callers fill slots from last to first, so
  // earlier slots are either unassigned or hold stale values that are about
  // to be overwritten. The range starts at `pattern_index + 1`, which is at
  // most indices.size(), so the iterator always stays within
  // [begin(), end()].
  //
  // The newest binding (info_vec.back()) is never returned. It is pinned to
  // a single slot by initialize_indices.
  auto find_latest_unused_binding =
      [&](size_t pattern_index, const std::vector<size_t>& indices) -> std::optional<size_t> {
    auto later_slots_begin = indices.begin() + (pattern_index + 1);
    for (size_t j_rev = 1; j_rev < info_vec.size(); j_rev++) {
      size_t j = info_vec.size() - j_rev - 1;
      if (info_vec[j].matches[pattern_index] && !info_vec[j].used &&
          std::none_of(later_slots_begin, indices.end(),
                       [j](size_t later_binding_index) { return j == later_binding_index; })) {
        return j;
      }
    }
    return std::nullopt;
  };

  // Initialize a vector of indices, each of which corresponds to a
  // potential match for a tuple element.
  //
  // \param tuple_index_of_current_expr The pattern slot to which the most
  // recent binding is pinned.
  //
  // \param indices An output vector, into which indices will be
  // generated.
  //
  // \returns bool True if the indices could be initialized to a
  // potential match.  False, otherwise.
  auto initialize_indices = [&](size_t tuple_index_of_current_expr,
                                std::vector<size_t>& indices) -> bool {
    if (!info_vec.back().matches[tuple_index_of_current_expr]) {
      return false;
    }

    // info_vec.size() is one past the last binding. It marks slots that
    // have not been assigned yet and never compares equal to a real index.
    indices = std::vector<size_t>(patterns.size(), info_vec.size());

    indices[tuple_index_of_current_expr] = info_vec.size() - 1;

    // Fill slots from last to first, so each search sees all later slots
    // already assigned.
    for (size_t i_rev = 0; i_rev < indices.size(); i_rev++) {
      size_t i = indices.size() - i_rev - 1;
      if (indices[i] == info_vec.size() - 1) {
        // The slot pinned to the most recent binding.
        continue;
      }

      auto binding_index = find_latest_unused_binding(i, indices);
      if (binding_index.has_value()) {
        indices[i] = binding_index.value();
      } else {
        return false;
      }
    }

    return true;
  };

  // Advance `indices` to the next candidate assignment, in odometer fashion.
  // Returns false once no further valid assignment exists.
  auto decrement_indices = [&](std::vector<size_t>& indices) -> bool {
    ICHECK_EQ(indices.size(), patterns.size());

    // Step 1: find the first index that can be decremented while still
    // producing a valid set of indices. An index that cannot be decremented
    // is left at zero; step 3 refills it.
    size_t i_forward;
    for (i_forward = 0; i_forward < indices.size(); i_forward++) {
      if (indices[i_forward] == info_vec.size() - 1) {
        continue;
      }

      bool found_valid = false;
      size_t& index = indices[i_forward];
      while (index) {
        index--;
        if (info_vec[index].matches[i_forward] && !info_vec[index].used &&
            std::all_of(
                indices.begin() + (i_forward + 1), indices.end(),
                [index](size_t later_binding_index) { return index != later_binding_index; })) {
          found_valid = true;
          break;
        }
      }
      if (found_valid) {
        break;
      }
    }

    // Step 2: if we reached the end, every index was decremented to zero
    // without finding a valid candidate. Return false to signal that the
    // enumeration is exhausted.
    if (i_forward == indices.size()) {
      return false;
    }

    // Step 3: refill every index before i_forward, going from
    // i_forward - 1 down to 0. Each one is reset to the latest binding that
    // is still compatible with the slots after it.
    for (size_t i = 0; i < i_forward; i++) {
      size_t i_backward = i_forward - (i + 1);
      if (indices[i_backward] == info_vec.size() - 1) {
        continue;
      }

      auto binding_index = find_latest_unused_binding(i_backward, indices);
      if (binding_index.has_value()) {
        indices[i_backward] = binding_index.value();
      } else {
        return false;
      }
    }

    return true;
  };

  for (size_t i_binding = 0; i_binding < bindings.size(); i_binding++) {
    const auto& binding = bindings[i_binding];

    auto expr = GetBoundValue(binding);

    binding_index_lookup[binding->var] = i_binding;

    info_vec.push_back(VarInfo{
        binding->var,
        expr,
        patterns.Map(
            [&](const DFPattern& pat) { return ExtractMatchedExpr(pat, expr, binding_lookup); }),
        std::unordered_set<Var>(),
        false,
    });

    // Try each slot (last first) as the one filled by the newest binding,
    // and enumerate candidate assignments for the remaining slots.
    auto new_match = [&]() -> std::optional<std::pair<std::vector<size_t>, std::vector<Expr>>> {
      std::vector<size_t> indices;
      for (size_t i = 0; i < patterns.size(); i++) {
        if (initialize_indices(patterns.size() - i - 1, indices)) {
          do {
            if (auto match = TryMatchByBindingIndex(info_vec, indices)) {
              return std::pair{indices, match.value()};
            }
          } while (decrement_indices(indices));
        }
      }
      return std::nullopt;
    }();

    if (new_match) {
      const auto& [indices, exprs] = new_match.value();
      ICHECK_EQ(indices.size(), exprs.size());
      for (size_t i = 0; i < indices.size(); i++) {
        ICHECK_LT(indices[i], info_vec.size());
        auto& info = info_vec[indices[i]];

        ICHECK(!info.used) << "InternalError: "
                           << "Produced multiple replacements for variable " << info.var;

        rewrites.Set(info.var, exprs[i]);
        binding_lookup.erase(info.var);
        info.used = true;
      }
    } else {
      binding_lookup.Set(binding->var, expr);
    }

    // Record this binding as a downstream user of each earlier variable it reads.
    for (const auto& prev_var : FreeVars(expr)) {
      if (auto it = binding_index_lookup.find(prev_var); it != binding_index_lookup.end()) {
        info_vec[it->second].downstream_usage.insert(binding->var);
      }
    }
  }

  return rewrites;
}