in src/core/compile.cpp [1451:1523]
// Generates C source for the loop-tree subtree rooted at `ref`.
//
// When `ref == -1` this emits a complete translation unit:
//   * the standard C includes,
//   * a `max(a,b)` macro if any IR node uses Operation::max,
//   * a `set(mem, val, length)` fill helper if `set_called` is true,
//   * a comment describing the memory slots (from memory_sizes()),
//   * `void fn_<count>(void** memory)` containing the code for every root.
// Otherwise, NODE entries are rendered with gen_node_string(), and all other
// entries with gen_loop_string(), which receives `overrides`.
std::string Compiler::gen_string(
    LoopTree::TreeRef ref,
    std::unordered_map<IR::VarRef, int> overrides) const {
  if (ref == -1) {
    // Generate the body first so the header below only contains the helpers
    // the body needs. `set_called` is read after this point; presumably it is
    // updated while generating the body (TODO confirm).
    std::stringstream body;
    for (auto c : lt.roots) {
      body << gen_string(c);
    }

    std::stringstream ss;

    // The max() macro is only emitted if some IR node uses Operation::max.
    bool define_max = false;
    for (auto n : lt.ir.nodes()) {
      if (lt.ir.node(n).op() == Operation::max) {
        define_max = true;
        break;
      }
    }

    ss << R"""(#include <math.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
)""";
    if (define_max) {
      ss << R"""(
#define max(a,b) \
({ __typeof__ (a) _a = (a); \
__typeof__ (b) _b = (b); \
_a > _b ? _a : _b; })
)""";
    }
    if (set_called) {
      ss << R"""(
static inline void set(float* mem, float val, int64_t length) {
for (int64_t i = 0; i < length; ++i) {
mem[i] = val;
}
}
)""";
    }
    ss << "\n";

    const auto &sizes = memory_sizes();
    // Only the trailing run of entries with size <= 1 is dropped from the
    // memory comment. Interior entries with size <= 1 are still listed (as
    // "nullptr") so every listed slot keeps its position.
    int num_listed = static_cast<int>(sizes.size());
    while (num_listed > 0 && sizes.at(num_listed - 1) <= 1) {
      --num_listed;
    }

    ss << "// memory := { ";
    for (int i = 0; i < num_listed; ++i) {
      auto s = sizes.at(i);
      if (s <= 1) {
        ss << "nullptr";
      } else {
        ss << "float[" << s << "]";
      }
      if (i + 1 != num_listed) {
        ss << ", ";
      }
    }
    ss << " }\n";

    ss << "void fn_" << count << "(void** memory) {\n";
    ss << body.str();
    ss << "}\n";
    return ss.str();
  }
  if (lt.kind(ref) == LoopTree::NODE) {
    return gen_node_string(ref);
  }
  return gen_loop_string(ref, overrides);
}