Compiler::Access Compiler::gen_access(IR::NodeRef node_ref, LoopTree::TreeRef ref) const

in src/core/compile.cpp [1955:2086]


/// Builds the memory-access description used when `node_ref` is read at loop
/// tree location `ref`.
///
/// The node whose buffer is actually read is `resolved_reads.at(node_ref)`
/// (presumably different from `node_ref` when `node_ref` is a view — confirm).
/// Its allocation supplies the per-dimension sizes, and `gen_index_equations`
/// supplies one index expression per allocation dimension (`.first`); every
/// expression in `.second` must be a bare symbol ("viewed writes" are not
/// supported yet).
///
/// The returned Access holds:
///  - `total_offset`: the flattened index (row-major over `alloc.sizes`, last
///    dimension fastest) evaluated with `size(sym)` replaced by the concrete
///    variable size and every remaining symbol set to 0;
///  - `vars[v]`: a (stride, offset, max) tuple for each variable `v` that is
///    both used by the node at `ref` and in scope there. `max` is -1 when the
///    variable indexes a dimension directly or when the read dimension is at
///    least as large as `v`'s size. Variables whose stride or offset does not
///    reduce to a constant get no entry.
Compiler::Access Compiler::gen_access(IR::NodeRef node_ref,
                                      LoopTree::TreeRef ref) const {
  const auto read_node_ref = resolved_reads.at(node_ref);
  auto view_exprs = gen_index_equations(read_node_ref, node_ref, ref);
  for (const auto &e : view_exprs.second) {
    ASSERT(e.type() == Expr::Type::symbol) << "Viewed writes not yet supported";
  }

  // auto view_exprs = gen_constraints(node_ref, ref);

  // This is the robust way to calculate view-based accesses, but is currently a
  // WIP
  // TODO, integrate more fully and simplify gen_access
  // if (base_node_ref != node_ref) {
  //  auto hotfix_exprs = gen_index_equations(base_node_ref, node_ref, ref);
  //  view_exprs.clear();
  //  for (auto i = 0; i < hotfix_exprs.first.size(); ++i) {
  //    auto v = lt.ir.node(base_node_ref).vars().at(i);
  //    auto sym = var_to_sym.at(v);
  //    // later logic can't process this yet
  //    if (hotfix_exprs.first.at(i) == Expr(sym)) {
  //      continue;
  //    }
  //    view_exprs.emplace_back(sym, hotfix_exprs.first.at(i));
  //  }
  //}
  const auto &read_node = lt.ir.node(read_node_ref);
  auto alloc = allocations.at(read_node_ref);

  // Only variables that the using node touches AND that are in scope at `ref`
  // get a per-variable stride entry.
  const auto use_node_ref = lt.node(ref);
  auto node_vars = to_set(lt.ir.all_vars(use_node_ref));
  auto scope_vars = lt.scope_vars(ref);
  auto vars = intersection(node_vars, scope_vars);

  // Symbols of the read node's own variables, one per dimension.
  // NOTE(review): this vector is indexed in parallel with `read_exprs` below,
  // but variables without a symbol are skipped here, which would shift the
  // indices. Presumably every read variable has a symbol — confirm.
  std::vector<symbolic::Symbol> read_symbols;
  for (auto v : read_node.vars()) {
    if (var_to_sym.count(v)) {
      read_symbols.emplace_back(var_to_sym.at(v));
    }
  }
  const auto &read_exprs = view_exprs.first;

  // Evaluates `expr` at the origin. First every `size(sym)` term is replaced by
  // the concrete size of sym's variable. Then every remaining symbol is set to
  // 0, and the result is simplified.
  auto zero = [&](const symbolic::Expr &expr) {
    auto sized = expr.walk([&](const symbolic::Expr &e) {
                       if (e.op() == symbolic::Op::size) {
                         // Copy: `args()` may return a temporary.
                         auto arg = e.args().at(0);
                         if (arg.type() == Expr::Type::symbol) {
                           auto s = var_sizes.at(sym_to_var.at(arg.symbol()));
                           return Expr(s);
                         }
                       }
                       return e;
                     })
                     .simplify();
    return sized
        .walk([&](const symbolic::Expr &e) {
          if (e.type() == Expr::Type::symbol) {
            return Expr(0);
          }
          return e;
        })
        .simplify();
  };

  ASSERT(alloc.sizes.size() == read_exprs.size());
  // Row-major stride of allocation dimension `idx`: the product of the sizes of
  // all later dimensions, where non-positive sizes count as 1. A dimension
  // whose own size is non-positive gets stride 0.
  auto stride_at = [&](size_t idx) {
    int64_t stride = alloc.sizes.at(idx) > 0 ? 1 : 0;
    for (auto i = idx + 1; i < alloc.sizes.size(); ++i) {
      auto size = alloc.sizes.at(i);
      stride *= size > 0 ? size : 1;
    }
    return stride;
  };

  Access access(alloc);
  Expr idx_expr(0);
  for (size_t i = 0; i < read_exprs.size(); ++i) {
    idx_expr = idx_expr + read_exprs.at(i) * Expr(stride_at(i));
  }
  access.total_offset = zero(idx_expr).value();

  for (auto v : vars) {
    // NB: ok to generate a fake symbol, we only use it for lookup
    auto sym = var_to_sym.count(v) ? var_to_sym.at(v) : Symbol();
    bool found_expr = false;
    for (size_t i = 0; i < read_exprs.size(); ++i) {
      const auto &e = read_exprs.at(i);
      auto read_var = read_node.vars().at(i);
      const auto &read_sym = read_symbols.at(i);
      // `v` indexes this dimension directly: plain stride, no offset or bound.
      if (read_sym == sym) {
        access.vars[v] = std::make_tuple(stride_at(i), 0, -1);
        found_expr = true;
        continue;
      }
      if (!e.contains(sym)) {
        continue;
      }
      ASSERT(!found_expr)
          << "Found two dependencies on the same variable, not yet supported";
      found_expr = true;
      auto constraint = Constraint(read_sym, e);
      auto offset = zero(e);
      auto expr = isolate(constraint, sym).second;
      // Simplify the whole product so that a constant stride is recognized as
      // a value below. Previously only `Expr(stride_at(i))` was simplified.
      auto stride =
          (differentiate(expr, read_sym) * Expr(stride_at(i))).simplify();
      // Only keep a bound when the read dimension is smaller than `v`'s range.
      auto max = var_sizes.at(read_var);
      auto v_max = var_sizes.at(v);
      if (max >= v_max) {
        max = -1;
      }
      if (stride.type() != Expr::Type::value ||
          offset.type() != Expr::Type::value) {
        continue;
      }
      access.vars[v] = std::make_tuple(stride.value(), offset.value(), max);
    }
  }
  return access;
}