in src/core/compile.cpp [1277:1368]
// Emits one C-like assignment statement for the compute node at `ref` in the
// loop tree, e.g. `out = a + b;` or `out = max(out, in);` (the exact access
// text comes from gen_access_string).
//
// Operand order: if the node has reduction variables, the node's own output
// access is placed first so the statement accumulates into itself; the node's
// inputs follow in the order returned by node.inputs().
//
// Rendering by operation kind:
//   * infix (add, multiply, subtract, divide): operands joined by " <op> "
//   * binary function (max): right-nested calls, max(a, max(b, c))
//   * unary (exp, sqrt, negate, reciprocal): `<op>(x)`; exactly one operand
//     is required.
// Any other operation trips ASSERT before any code is generated.
std::string Compiler::gen_compute_node_string(LoopTree::TreeRef ref) const {
  std::stringstream ss;
  const auto &node_ref = lt.node(ref);
  const auto &node = lt.ir.node(node_ref);

  const bool is_infix = [&]() {
    switch (node.op()) {
      case Operation::add:
      case Operation::multiply:
      case Operation::subtract:
      case Operation::divide:
        return true;
      default:
        return false;
    }
  }();

  // Every infix op is also binary; max is binary but rendered as a call.
  const bool is_binary = [&]() {
    switch (node.op()) {
      case Operation::add:
      case Operation::multiply:
      case Operation::subtract:
      case Operation::divide:
      case Operation::max:
        return true;
      default:
        return false;
    }
  }();

  // Evaluated eagerly so unsupported ops fail before any access string is
  // generated.
  const char *op = [&]() {
    switch (node.op()) {
      case Operation::add:
        return "+";
      case Operation::multiply:
        return "*";
      case Operation::subtract:
        return "-";
      case Operation::divide:
        return "/";
      case Operation::max:
        return "max";
      case Operation::exp:
        return "exp";
      case Operation::sqrt:
        return "sqrt";
      case Operation::negate:
        return "-";
      case Operation::reciprocal:
        // Rendered as "1 / (x)" by the unary branch below.
        return "1 / ";
      default:
        ASSERT(0) << "can't emit code for " << dump(node.op());
        return "";
    }
  }();

  ss << gen_access_string(node_ref, ref);
  ss << " = ";

  const bool is_reduction = lt.ir.reduction_vars(node_ref).size() > 0;
  std::vector<std::string> access_strings;
  if (is_reduction) {
    // Reductions read their own output as the first (accumulator) operand.
    access_strings.emplace_back(gen_access_string(node_ref, ref));
  }
  for (const auto &inp : node.inputs()) {
    access_strings.emplace_back(gen_access_string(inp, ref));
  }

  if (is_infix) {
    for (std::size_t i = 0; i < access_strings.size(); ++i) {
      if (i != 0) {
        ss << " " << op << " ";
      }
      ss << access_strings[i];
    }
  } else if (is_binary) {
    // Right fold: op(a0, op(a1, ... op(a_{n-2}, a_{n-1}) ...)).
    ASSERT(!access_strings.empty()) << "no operands for " << dump(node.op());
    const std::size_t last = access_strings.size() - 1;
    for (std::size_t i = 0; i < last; ++i) {
      ss << op << "(" << access_strings[i] << ", ";
    }
    ss << access_strings[last];
    ss << std::string(last, ')');
  } else {
    ASSERT(access_strings.size() == 1);
    ss << op << "(" << access_strings.at(0) << ")";
  }
  ss << ";";
  return ss.str();
}