std::string Compiler::gen_compute_node_string()

in src/core/compile.cpp [1277:1368]


/// Generates the source text of a single statement computing the node at `ref`.
///
/// The statement has the shape `<out> = <expr>;`, where `<out>` is the access
/// string of the node itself and `<expr>` combines the access strings of the
/// node's inputs:
///   - add / multiply / subtract / divide are emitted infix: `a + b + c`
///   - max is emitted as right-nested calls: `max(a, max(b, c))`
///   - exp / sqrt / negate / reciprocal take exactly one operand and are
///     emitted as `exp(a)`, `sqrt(a)`, `-(a)`, `1 / (a)`
/// If the node has reduction variables, its own access string is prepended
/// as the first operand, producing e.g. `out = out + a;`.
/// Any other operation trips the ASSERT in the operator table.
///
/// @param ref loop-tree reference whose compute node is emitted.
/// @return the generated statement, terminated by `;`.
std::string Compiler::gen_compute_node_string(LoopTree::TreeRef ref) const {
  const auto &node_ref = lt.node(ref);
  const auto &node = lt.ir.node(node_ref);
  const auto node_op = node.op();

  // Operations written between their operands (`a + b`).
  const bool is_infix = [&]() {
    switch (node_op) {
      case Operation::add:
      case Operation::multiply:
      case Operation::subtract:
      case Operation::divide:
        return true;
      default:
        return false;
    }
  }();
  // Every infix op is binary; max is the only binary op emitted as a call.
  const bool is_binary = is_infix || node_op == Operation::max;

  // Textual spelling of the operator (infix symbol or function-call prefix).
  const auto op = [&]() {
    switch (node_op) {
      case Operation::add:
        return "+";
      case Operation::multiply:
        return "*";
      case Operation::subtract:
        return "-";
      case Operation::divide:
        return "/";
      case Operation::max:
        return "max";
      case Operation::exp:
        return "exp";
      case Operation::sqrt:
        return "sqrt";
      case Operation::negate:
        return "-";
      case Operation::reciprocal:
        return "1 / ";
      default:
        ASSERT(0) << "can't emit code for " << dump(node_op);
        return "";
    }
  }();

  // The node's own access string is both the assignment target and, for
  // reductions, the first operand; compute it once.
  const std::string self_access = gen_access_string(node_ref, ref);
  const bool is_reduction = !lt.ir.reduction_vars(node_ref).empty();

  std::vector<std::string> operands;
  if (is_reduction) {
    operands.push_back(self_access);
  }
  for (const auto &inp : node.inputs()) {
    operands.emplace_back(gen_access_string(inp, ref));
  }

  std::ostringstream ss;
  ss << self_access << " = ";

  if (is_infix) {
    // a op b op c ...
    for (std::size_t i = 0; i < operands.size(); ++i) {
      if (i > 0) {
        ss << " " << op << " ";
      }
      ss << operands[i];
    }
  } else if (is_binary) {
    // Right-nested calls: op(a, op(b, c)); a single operand is emitted bare.
    ASSERT(!operands.empty()) << "no operands for " << dump(node_op);
    const std::size_t last = operands.size() - 1;
    for (std::size_t i = 0; i < last; ++i) {
      ss << op << "(" << operands.at(i) << ", ";
    }
    ss << operands.at(last) << std::string(last, ')');
  } else {
    // Unary operation: exactly one operand.
    ASSERT(operands.size() == 1);
    ss << op << "(" << operands.at(0) << ")";
  }
  ss << ";";
  return ss.str();
}