in src/core/compile.cpp [1955:2086]
// Builds the Access record describing how the node at loop-tree position
// `ref` reads the value that `node_ref` resolves to (resolved_reads).
//
// The result wraps the allocation of the resolved read node and carries:
//   * total_offset: the flattened index expression (sum over dimensions of
//     index expression * stride) evaluated with every size() term replaced by
//     the concrete variable size and every symbol replaced by 0.
//   * vars[v] for each variable v that is both used by the node at `ref` and
//     in scope at `ref`: a (stride, offset, max) tuple. When a dimension is
//     indexed directly by v's symbol the tuple is (stride, 0, -1); when v only
//     appears inside a dimension's index expression the stride is derived by
//     isolating v and differentiating, the offset is the expression with all
//     symbols zeroed, and max is the size of the read variable (or -1 when
//     that size is not smaller than v's own size).
//
// Asserts that the second component of the index equations consists only of
// plain symbols ("Viewed writes not yet supported"), that there is one index
// expression per allocation dimension, and that no variable is matched by
// more than one dimension.
Compiler::Access Compiler::gen_access(IR::NodeRef node_ref,
                                      LoopTree::TreeRef ref) const {
  auto read_node_ref = resolved_reads.at(node_ref);
  auto view_exprs = gen_index_equations(read_node_ref, node_ref, ref);
  for (const auto &e : view_exprs.second) {
    ASSERT(e.type() == Expr::Type::symbol) << "Viewed writes not yet supported";
  }
  // auto view_exprs = gen_constraints(node_ref, ref);
  // This is the robust way to calculate view-based accesses, but is currently a
  // WIP
  // TODO, integrate more fully and simplify gen_access
  // if (base_node_ref != node_ref) {
  //  auto hotfix_exprs = gen_index_equations(base_node_ref, node_ref, ref);
  //  view_exprs.clear();
  //  for (auto i = 0; i < hotfix_exprs.first.size(); ++i) {
  //    auto v = lt.ir.node(base_node_ref).vars().at(i);
  //    auto sym = var_to_sym.at(v);
  //    // later logic can't process this yet
  //    if (hotfix_exprs.first.at(i) == Expr(sym)) {
  //      continue;
  //    }
  //    view_exprs.emplace_back(sym, hotfix_exprs.first.at(i));
  //  }
  //}
  const auto &read_node = lt.ir.node(read_node_ref);
  auto alloc = allocations.at(read_node_ref);
  auto use_node_ref = lt.node(ref);

  // Only variables that the using node touches AND that are in scope at `ref`
  // get an entry in access.vars.
  auto node_vars = to_set(lt.ir.all_vars(use_node_ref));
  auto scope_vars = lt.scope_vars(ref);
  auto vars = intersection(node_vars, scope_vars);

  // Symbols of the read node's variables.
  // NOTE(review): variables without an entry in var_to_sym are skipped here,
  // yet read_symbols is indexed below with the same index as
  // read_node.vars(). This presumably relies on every read variable having a
  // symbol -- confirm against the code that populates var_to_sym.
  std::vector<symbolic::Symbol> read_symbols;
  for (auto v : read_node.vars()) {
    if (var_to_sym.count(v)) {
      read_symbols.emplace_back(var_to_sym.at(v));
    }
  }

  // One index expression per dimension of the allocation.
  const auto &read_exprs = view_exprs.first;
  ASSERT(alloc.sizes.size() == read_exprs.size());

  // Evaluates `expr` at the origin: size(sym) terms become the concrete size
  // of the corresponding variable, then every remaining symbol becomes 0.
  auto zero = [&](const symbolic::Expr &expr) {
    auto sized = expr.walk([&](const symbolic::Expr &e) {
                       if (e.op() == symbolic::Op::size) {
                         auto arg = e.args().at(0);
                         if (arg.type() == Expr::Type::symbol) {
                           auto s = var_sizes.at(sym_to_var.at(arg.symbol()));
                           return Expr(s);
                         }
                       }
                       return e;
                     })
                     .simplify();
    return sized
        .walk([&](const symbolic::Expr &e) {
          if (e.type() == Expr::Type::symbol) {
            return Expr(0);
          }
          return e;
        })
        .simplify();
  };

  // Stride of dimension `idx`: the product of the sizes of all later
  // dimensions, where non-positive sizes count as 1. A dimension whose own
  // size is non-positive gets stride 0.
  auto stride_at = [&](size_t idx) {
    int64_t stride = alloc.sizes.at(idx) > 0 ? 1 : 0;
    for (size_t i = idx + 1; i < alloc.sizes.size(); ++i) {
      auto size = alloc.sizes.at(i);
      stride *= size > 0 ? size : 1;
    }
    return stride;
  };

  Access access(alloc);

  // Flatten the per-dimension index expressions and evaluate at the origin to
  // obtain the constant part of the access.
  Expr idx_expr(0);
  for (size_t i = 0; i < read_exprs.size(); ++i) {
    idx_expr = idx_expr + read_exprs.at(i) * Expr(stride_at(i));
  }
  access.total_offset = zero(idx_expr).value();

  for (auto v : vars) {
    // NB: ok to generate a fake symbol, we only use it for lookup
    auto sym = var_to_sym.count(v) ? var_to_sym.at(v) : Symbol();
    bool found_expr = false;
    for (size_t i = 0; i < read_exprs.size(); ++i) {
      const auto &e = read_exprs.at(i);
      auto read_var = read_node.vars().at(i);
      const auto &read_sym = read_symbols.at(i);

      // Dimension is indexed directly by this variable.
      if (read_sym == sym) {
        access.vars[v] = std::make_tuple(stride_at(i), 0, -1);
        found_expr = true;
        continue;
      }
      if (!e.contains(sym)) {
        continue;
      }
      ASSERT(!found_expr)
          << "Found two dependencies on the same variable, not yet supported";
      found_expr = true;

      // The variable appears inside a view expression: solve
      // `read_sym = e` for `sym` and differentiate with respect to the read
      // symbol to obtain the per-step stride.
      auto constraint = Constraint(read_sym, e);
      auto offset = zero(e).simplify();
      auto expr = isolate(constraint, sym).second;
      // Simplify the whole product (not just the constant factor) so that a
      // stride which folds to a constant is recognised as a value below.
      auto stride =
          (differentiate(expr, read_sym) * Expr(stride_at(i))).simplify();

      auto max = var_sizes.at(read_var);
      auto v_max = var_sizes.at(v);
      if (max >= v_max) {
        max = -1;
      }
      // Non-constant strides/offsets cannot be encoded in the tuple; skip.
      if (stride.type() != Expr::Type::value ||
          offset.type() != Expr::Type::value) {
        continue;
      }
      access.vars[v] = std::make_tuple(stride.value(), offset.value(), max);
    }
  }
  return access;
}