std::string Compiler::gen_string()

in src/core/compile.cpp [1451:1523]


// Generates C source for the loop tree.
//
// ref == -1: emits a complete translation unit consisting of
//   - standard C includes,
//   - a `max(a,b)` macro, only if some IR node uses Operation::max,
//   - a `set` helper, only if `set_called` is true,
//   - a comment describing the `memory` argument layout (from memory_sizes()),
//   - `void fn_<count>(void** memory)` whose body is the concatenated code
//     generated for every root of the loop tree.
// ref is a NODE: delegates to gen_node_string(ref).
// otherwise (a loop): delegates to gen_loop_string(ref, overrides).
std::string Compiler::gen_string(
    LoopTree::TreeRef ref,
    std::unordered_map<IR::VarRef, int> overrides) const {
  if (ref == -1) {
    // Generate the body first to minimize header code: `set_called` is read
    // below, after body generation (presumably it is updated while emitting
    // the body -- confirm against gen_node_string/gen_loop_string).
    std::stringstream body;
    for (auto c : lt.roots) {
      body << gen_string(c);
    }
    std::stringstream ss;

    // Only emit the max macro when at least one node needs it.
    bool define_max = false;
    for (auto n : lt.ir.nodes()) {
      if (lt.ir.node(n).op() == Operation::max) {
        define_max = true;
        break;
      }
    }

    ss << R"""(#include <math.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>

)""";
    if (define_max) {
      // Uses GCC/Clang statement expressions + __typeof__ so each argument is
      // evaluated exactly once.
      ss << R"""(
#define max(a,b) \
   ({ __typeof__ (a) _a = (a); \
       __typeof__ (b) _b = (b); \
     _a > _b ? _a : _b; })
)""";
    }

    if (set_called) {
      ss << R"""(
static inline void set(float* mem, float val, int64_t length) {
  for (int64_t i = 0; i < length; ++i) {
    mem[i] = val;
  }
}
)""";
    }

    ss << "\n";
    const auto &sizes = memory_sizes();
    // Drop only the *trailing* run of slots with size <= 1 from the comment.
    // Slots with size <= 1 that sit between real buffers are still listed (as
    // "nullptr") so every later buffer keeps its position in the list.
    auto num_listed = sizes.size();
    while (num_listed > 0 && sizes.at(num_listed - 1) <= 1) {
      --num_listed;
    }
    ss << "// memory := { ";
    for (decltype(num_listed) i = 0; i < num_listed; ++i) {
      if (i != 0) {
        ss << ", ";
      }
      const auto s = sizes.at(i);
      if (s <= 1) {
        ss << "nullptr";
      } else {
        ss << "float[" << s << "]";
      }
    }
    ss << " }\n";
    ss << "void fn_" << count << "(void** memory) {\n";
    ss << body.str();
    ss << "}\n";
    return ss.str();
  }
  if (lt.kind(ref) == LoopTree::NODE) {
    return gen_node_string(ref);
  }
  return gen_loop_string(ref, overrides);
}