Jit/lir/parser.cpp (490 lines of code) (raw):
// Copyright (c) Facebook, Inc. and its affiliates. (http://www.facebook.com)
#include "Jit/lir/parser.h"
#include "Jit/codegen/code_section.h"
#include "Jit/codegen/x86_64.h"
#include "Jit/lir/operand.h"
#include "Jit/lir/symbol_mapping.h"
#include <algorithm>
#include <cctype>
#include <cstring>
#include <regex>
#include <string>
#include <utility>
namespace jit {
namespace lir {
// Returns the set that keeps the text of string literals alive. The set is a
// function-local static, so every call hands back the same object; parseInput()
// stores pointers to the characters of its elements as immediates.
std::unordered_set<std::string>& GetStringLiterals() {
  static std::unordered_set<std::string> interned_literals;
  return interned_literals;
}
// Lexes one token from the beginning of `str`.
//
// Every rule in the table below is anchored with "^" and tried in order; the
// first rule that matches wins, so the order of the table is also the lexer's
// priority list. When the winning rule has a capture group, the text of its
// first group is converted with strtoll (base 0, i.e. a "0x" prefix selects
// hexadecimal and a leading "0" selects octal) and returned as the token's
// data. Returns a kError token when no rule matches.
Parser::Token Parser::getNextToken(const char* str) {
  // A compiled, start-anchored regex together with the token type it yields.
  struct TokenRule {
    TokenRule(const char* pattern, TokenType token_type)
        : re(std::string("^") + pattern), type(token_type) {}

    std::regex re;
    TokenType type;
  };

  static const std::vector<TokenRule> rules{
      {"Function:.*\n", kFunctionStart},
      {"BB %(\\d+)( - .*)?\n", kBasicBlockStart},
      {"\n", kNewLine},
      {"%(\\d+)", kVReg},
      {"R[0-9A-Z]+", kPhyReg},
      {"\\[RBP[ ]?-[ ]?(\\d+)\\]", kStack},
      {"\\[(0x[0-9a-fA-F]+)\\]", kAddress},
      {"(\\d+)(\\(0x[0-9a-fA-F]+\\))?", kImmediate},
      {"BB%(\\d+)", kBasicBlockRef},
      {"[A-Za-z_][A-Za-z0-9_]+", kId},
      {"=", kEqual},
      {",", kComma},
      {"\\(", kParLeft},
      {"\\)", kParRight},
      {"#.*\n", kComment},
      {":[A-Za-z0-9]+", kDataType},
      {"\\[[^\\]]*\\]", kIndirect},
      {"\"[^\"]+\"", kStringLiteral}};

  std::cmatch match;
  for (const auto& rule : rules) {
    if (std::regex_search(str, match, rule.re)) {
      if (match.size() <= 1) {
        // The rule has no capture group: the token carries no data.
        return {rule.type, match.length()};
      }
      return {
          rule.type,
          match.length(),
          strtoll(match.str(1).c_str(), nullptr, 0)};
    }
  }
  return {kError};
}
// Throws a ParserException unless `cond` holds.
//
// On failure, the reason `msg` and the remaining input starting at `cur`
// (truncated to 64 characters) are logged before the exception is thrown.
// `cur` must point at a NUL-terminated string because its length is measured
// with strlen.
static void expect(bool cond, const char* cur, const char* msg = "") {
  if (cond) {
    return;
  }

  JIT_LOG("Unable to parse - %s", msg);
  constexpr size_t kMaxContextLength = 64;
  if (strlen(cur) > kMaxContextLength) {
    std::string m(cur, cur + kMaxContextLength);
    m += "...";
    JIT_LOG("String from %s", m);
  } else {
    JIT_LOG("Starting from %s", cur);
  }

  // fmt::format only understands "{}" replacement fields. A printf-style "%s"
  // would be copied verbatim and the reason would never reach the exception
  // text.
  throw ParserException(fmt::format("Unable to parse - {}", msg));
}
// Returns a reference to the value stored under `key` in `map`. If the key is
// absent, throws an exception of type Exc instead.
template <typename Exc, typename M, typename K>
static auto& map_get_throw(M& map, const K& key) {
  const auto entry = map.find(key);
  if (entry != map.end()) {
    return entry->second;
  }
  throw Exc("Unable to parse - key not in map");
}
// Parses the textual LIR in `code` and returns the Function it describes.
//
// The text is consumed one token at a time (see getNextToken) and driven
// through a state machine whose states name what is expected next: a function
// header, a basic block header, or one of the pieces of an instruction line
// ("output[:type] = Opcode input[:type], ..."; phi-style inputs are written as
// "(BB%id, input[:type])"). References to basic blocks and virtual registers
// are only recorded while scanning and are resolved once the whole text has
// been read.
//
// Throws ParserException on any syntax error, including input that contains
// no function header at all.
std::unique_ptr<Function> Parser::parse(const std::string& code) {
  enum {
    FUNCTION,
    BASIC_BLOCK,
    INSTR_OUTPUT,
    INSTR_OUTPUT_TYPE,
    INSTR_EQUAL,
    INSTR_NAME,
    INSTR_INPUT,
    INSTR_INPUT_TYPE,
    INSTR_INPUT_COMMA,
    PHI_INPUT_FIRST,
    PHI_INPUT_COMMA,
    PHI_INPUT_SECOND,
    PHI_INPUT_SECOND_TYPE,
    PHI_INPUT_PAR,
  } state = FUNCTION;

  std::unique_ptr<Function> func;

  const char* codestr = code.c_str();
  const char* cur = codestr;
  const char* end = codestr + code.size();

  while (cur != end) {
    auto token = getNextToken(cur);
    auto type = token.type;

    // Dispatch loop for the current token. `continue` re-dispatches the same
    // token after a state change; falling out of the switch reaches the
    // trailing `break`, which moves on to the next token.
    while (true) {
      if (token.type == kComment) {
        // skip comments for now
        // NOTE(review): the kComment pattern in getNextToken includes the
        // terminating newline, so a comment never produces a kNewLine
        // transition.
        break;
      }

      switch (state) {
        case FUNCTION: {
          // expect a function start
          if (type == kNewLine) {
            break;
          }
          expect(type == kFunctionStart, cur, "Expect a function start.");
          func = std::make_unique<Function>();
          func_ = func.get();
          state = BASIC_BLOCK;
          break;
        }
        case BASIC_BLOCK: {
          // expect a basic block start
          if (type == kNewLine) {
            break;
          }
          expect(type == kBasicBlockStart, cur, "Expect a basic block start.");
          int id = token.data;
          block_ = func_->allocateBasicBlock();
          block_->setId(id);
          auto pair = block_index_map_.emplace(id, block_);
          expect(pair.second, cur, "Duplicated basic block id.");
          // The header line may carry "- section: ..." and "- succs: ..."
          // annotations; both helpers scan the full header text.
          setSection(std::string(cur, token.length), block_);
          setSuccessorBlocks(std::string(cur, token.length), block_);
          state = INSTR_OUTPUT;
          break;
        }
        case INSTR_OUTPUT: {
          if (type == kNewLine) {
            break;
          } else if (type == kBasicBlockStart) {
            // A new block begins; let BASIC_BLOCK handle this same token.
            state = BASIC_BLOCK;
            continue;
          }
          // Start a new instruction. The opcode is filled in by INSTR_NAME,
          // and an id of -1 marks "no id yet" for fixUnknownIds().
          instr_ = block_->allocateInstr(Instruction::kNone, nullptr);
          instr_->setId(-1);
          auto output = instr_->output();
          if (type == kId) {
            // The line starts with the opcode, so there is no output operand.
            state = INSTR_NAME;
            continue;
          } else if (type == kVReg) {
            output->setVirtualRegister();
            auto pair = output_index_map_.emplace(token.data, instr_);
            instr_->setId(token.data);
            expect(pair.second, cur, "Duplicated output virtual register.");
          } else if (type == kPhyReg) {
            output->setPhyRegister(jit::codegen::PhyLocation::parse(
                std::string(cur, token.length)));
          } else if (type == kStack) {
            output->setStackSlot(token.data);
          } else if (type == kAddress) {
            output->setMemoryAddress(reinterpret_cast<void*>(token.data));
          } else if (type == kImmediate) {
            output->setConstant(token.data);
          } else if (type == kIndirect) {
            parseIndirect(output, std::string_view(cur, token.length), cur);
          } else {
            expect(false, cur);
          }
          state = INSTR_OUTPUT_TYPE;
          break;
        }
        case INSTR_OUTPUT_TYPE: {
          if (type == kEqual) {
            // The data type annotation is optional.
            state = INSTR_EQUAL;
            continue;
          }
          expect(type == kDataType, cur, "Expect output data type.");
          instr_->output()->setDataType(
              getOperandDataType(std::string(cur, token.length)));
          state = INSTR_EQUAL;
          break;
        }
        case INSTR_EQUAL: {
          expect(type == kEqual, cur, "Expect \"=\".");
          state = INSTR_NAME;
          break;
        }
        case INSTR_NAME: {
          expect(type == kId, cur, "Expect an instruction name.");
          instr_->setOpcode(getInstrOpcode(std::string(cur, token.length)));
          state = INSTR_INPUT;
          break;
        }
        case INSTR_INPUT: {
          if (type == kNewLine) {
            state = INSTR_OUTPUT;
            break;
          }
          if (type == kParLeft) {
            state = PHI_INPUT_FIRST;
          } else {
            parseInput(token, cur);
            state = INSTR_INPUT_TYPE;
          }
          break;
        }
        case INSTR_INPUT_TYPE: {
          if (type == kComma || type == kNewLine) {
            // No data type annotation on this input.
            state = INSTR_INPUT_COMMA;
            continue;
          }
          expect(type == kDataType, cur, "Expect input data type.");
          expect(
              instr_->getNumInputs() > 0,
              cur,
              "Expect data type to follow an input.");
          OperandBase* input_base =
              instr_->getInput(instr_->getNumInputs() - 1);
          // The annotation is only applied to inputs that are not linked
          // operands.
          if (!input_base->isLinked()) {
            Operand* input = static_cast<Operand*>(input_base);
            auto data_type = getOperandDataType(std::string(cur, token.length));
            input->setDataType(data_type);
          }
          state = INSTR_INPUT_COMMA;
          break;
        }
        case INSTR_INPUT_COMMA: {
          // expect commas between inputs
          if (type == kNewLine) {
            state = INSTR_OUTPUT;
            break;
          }
          expect(type == kComma, cur, "Expect a comma.");
          state = INSTR_INPUT;
          break;
        }
        case PHI_INPUT_FIRST: {
          // first argument of phi input pairs - basic block id
          expect(type == kBasicBlockRef, cur, "Expect a basic block id.");
          parseInput(token, cur);
          state = PHI_INPUT_COMMA;
          break;
        }
        case PHI_INPUT_COMMA: {
          expect(type == kComma, cur, "Expect a comma.");
          state = PHI_INPUT_SECOND;
          break;
        }
        case PHI_INPUT_SECOND: {
          // second argument of phi input pairs - a variable
          parseInput(token, cur);
          state = PHI_INPUT_SECOND_TYPE;
          break;
        }
        case PHI_INPUT_SECOND_TYPE: {
          if (type == kParRight) {
            // No data type annotation on the second element of the pair.
            state = PHI_INPUT_PAR;
            continue;
          }
          expect(type == kDataType, cur, "Expect phi input second data type.");
          expect(
              instr_->getNumInputs() > 0,
              cur,
              "Expect data type to follow an input.");
          OperandBase* input_base =
              instr_->getInput(instr_->getNumInputs() - 1);
          if (!input_base->isLinked()) {
            Operand* input = static_cast<Operand*>(input_base);
            auto data_type = getOperandDataType(std::string(cur, token.length));
            input->setDataType(data_type);
          }
          state = PHI_INPUT_PAR;
          break;
        }
        case PHI_INPUT_PAR: {
          // expect a right parenthesis
          expect(type == kParRight, cur, "Expect a right parenthesis");
          state = INSTR_INPUT_COMMA;
          break;
        }
      }
      break;
    }

    cur += token.length;
    // skip whitespaces
    while (cur != end && (*cur == ' ' || *cur == '\t')) {
      cur++;
    }
  }

  // Input without a function header (empty text, or only blank lines and
  // comments) never creates a Function. The fix-up passes below dereference
  // func_, so reject such input here instead of touching a pointer this call
  // never set.
  expect(func != nullptr, cur, "Expect a function start.");

  fixOperands();
  connectBasicBlocks();
  fixUnknownIds();

  return func;
}
// Applies the code section named in a basic block header to `bb`.
//
// `bbdef` is the text of the header; if it contains "- section: hot" or
// "- section: cold", the matching codegen::CodeSection is set on `bb`.
// Without such an annotation `bb` is left untouched.
void Parser::setSection(const std::string& bbdef, BasicBlock* bb) {
  // Compiled once: constructing a std::regex is expensive and this function
  // runs for every basic block header.
  static const std::regex section_re("- section: (hot|cold)");

  std::cmatch section_m;
  if (!std::regex_search(bbdef.c_str(), section_m, section_re)) {
    return;
  }

  const std::string section = section_m.str(1);
  if (section == "hot") {
    bb->setSection(codegen::CodeSection::kHot);
  } else {
    JIT_CHECK(section == "cold", "Code section must be hot or cold.");
    bb->setSection(codegen::CodeSection::kCold);
  }
}
// Records the successors listed in a basic block header.
//
// `bbdef` is the text of the header; an annotation of the form
// "- succs: %A" or "- succs: %A %B" appends (bb, A) and then (bb, B) to
// basic_block_succs_. The ids are resolved to blocks later by
// connectBasicBlocks(). Without such an annotation nothing is recorded.
void Parser::setSuccessorBlocks(const std::string& bbdef, BasicBlock* bb) {
  // Compiled once: constructing a std::regex is expensive and this function
  // runs for every basic block header.
  static const std::regex succ_re("- succs: %(\\d+)(?: %(\\d+))?");

  std::cmatch succ_m;
  if (!std::regex_search(bbdef.c_str(), succ_m, succ_re)) {
    return;
  }

  int64_t succ1 = atoll(succ_m.str(1).c_str());
  basic_block_succs_.emplace_back(bb, succ1);

  // The second successor is optional.
  if (succ_m[2].matched) {
    int64_t succ2 = atoll(succ_m.str(2).c_str());
    basic_block_succs_.emplace_back(bb, succ2);
  }
}
// Translates a data type annotation into its OperandBase::DataType.
//
// `name` is the annotation text including the leading colon (each table key
// is ":" followed by a type name from FOREACH_OPERAND_DATA_TYPE). Throws
// ParserException if `name` is not in the table.
OperandBase::DataType Parser::getOperandDataType(
    const std::string& name) const {
  using DataTypeByName =
      std::unordered_map<std::string, OperandBase::DataType>;

  // Expands to one {":<Type>", OperandBase::k<Type>} entry per data type.
#define DATA_TYPE_TABLE_ENTRY(v, ...) {":" #v, OperandBase::k##v},
  static const DataTypeByName data_type_by_name = {
      FOREACH_OPERAND_DATA_TYPE(DATA_TYPE_TABLE_ENTRY)};
#undef DATA_TYPE_TABLE_ENTRY

  return map_get_throw<ParserException>(data_type_by_name, name);
}
// Translates an instruction name into its Instruction::Opcode.
//
// The table holds one entry per item of FOREACH_INSTR_TYPE, keyed by the
// item's name. Throws ParserException if `name` is not in the table.
Instruction::Opcode Parser::getInstrOpcode(const std::string& name) const {
  using OpcodeByName = std::unordered_map<std::string, Instruction::Opcode>;

  // Expands to one {"<Name>", Instruction::k<Name>} entry per instruction.
#define OPCODE_TABLE_ENTRY(v, ...) {#v, Instruction::k##v},
  static const OpcodeByName opcode_by_name = {
      FOREACH_INSTR_TYPE(OPCODE_TABLE_ENTRY)};
#undef OPCODE_TABLE_ENTRY

  return map_get_throw<ParserException>(opcode_by_name, name);
}
void Parser::parseInput(const Token& token, const char* code) {
auto type = token.type;
switch (type) {
case kVReg: {
auto linked_opnd = instr_->allocateLinkedInput(nullptr);
auto id = token.data;
instr_refs_.emplace(linked_opnd, id);
break;
}
case kPhyReg: {
auto reg =
jit::codegen::PhyLocation::parse(std::string(code, token.length));
expect(
reg != jit::codegen::PhyLocation::REG_INVALID,
code,
"Unable to parse physical register.");
instr_->allocatePhyRegisterInput(reg);
break;
}
case kStack: {
instr_->allocateStackInput(token.data);
break;
}
case kAddress: {
instr_->allocateAddressInput(reinterpret_cast<void*>(token.data));
break;
}
case kImmediate: {
instr_->allocateImmediateInput(token.data);
break;
}
case kBasicBlockRef: {
auto opnd = instr_->allocateImmediateInput(0);
basic_block_refs_.emplace(opnd, token.data);
break;
}
case kIndirect: {
auto opnd = instr_->allocateMemoryIndirectInput(PhyLocation::REG_INVALID);
parseIndirect(opnd, std::string_view(code, token.length), code);
break;
}
case kId: {
uint64_t imm_addr = map_get_throw<ParserException>(
kSymbolMapping, std::string(code, token.length));
instr_->allocateImmediateInput(
reinterpret_cast<uint64_t>(imm_addr), OperandBase::kObject);
break;
}
case kStringLiteral: {
ThreadedCompileSerialize guard;
std::unordered_set<std::string>& v = GetStringLiterals();
auto ret = v.emplace(code, 1, token.length - 2);
instr_->allocateImmediateInput(
reinterpret_cast<uint64_t>((*ret.first).c_str()),
OperandBase::kObject);
break;
}
default:
expect(false, code, "Unable to parse instruction input.");
}
}
// Parses a memory indirect operand and stores the result in `opnd`.
//
// `token` is the bracketed operand text. It consists of a base, an optional
// index with an optional multiplier, and an optional hexadecimal offset,
// separated by single spaces:
//   [<base>]
//   [<base> + <index>]
//   [<base> + <index> * <multiplier>]
//   [<base> ... + 0x<offset>]  or  [<base> ... - 0x<offset>]
// where <base> and <index> are either a virtual register with a type suffix
// ("%<id>:<type>") or a physical register written as "<reg>:Object".
// The multiplier must be a non-zero power of two and is stored as its base-2
// logarithm.
//
// `code` is only used for error reporting. Throws ParserException if the
// operand is malformed or refers to a virtual register that has not been
// defined as an output yet.
void Parser::parseIndirect(
    Operand* opnd,
    std::string_view token,
    const char* code) {
  // Compiled once: constructing a std::regex is expensive and the patterns
  // never change between calls.
  static const std::regex base_reg("\\[%(\\d+):[0-9a-zA-Z]+");
  static const std::regex base_phys("\\[(R[0-9A-Z]+):Object");
  static const std::regex index_reg(
      "\\+ %(\\d+):[0-9a-zA-Z]+( \\* (\\d+))?");
  static const std::regex index_phys(
      "\\+ (R[0-9A-Z]+):Object( \\* (\\d+))?");
  static const std::regex offset_re("([\\+-]) (0x[0-9a-fA-F]+)");

  // std::cmatch describes matches over const char* ranges. string_view
  // iterators are not required to be raw pointers, so search over data().
  const char* const first = token.data();
  const char* const last = first + token.size();

  std::variant<Instruction*, PhyLocation> base =
      jit::codegen::PhyLocation::REG_INVALID;
  std::variant<Instruction*, PhyLocation> index = nullptr;
  uint8_t multiplier = 0;
  int32_t offset = 0;

  std::cmatch m;
  // keep track of length of parsed operand
  // start at 1 to account for the right bracket
  size_t expected_length = 1;

  // parse base register
  if (std::regex_search(first, last, m, base_reg)) {
    auto base_id = std::stoll(m.str(1), nullptr, 0);
    base = map_get_throw<ParserException>(output_index_map_, base_id);
    expected_length += m.length();
  } else if (std::regex_search(first, last, m, base_phys)) {
    base = jit::codegen::PhyLocation::parse(m.str(1));
    expected_length += m.length();
  } else {
    expect(false, code, "Expected a base register.");
  }

  // parse index and multiplier
  bool index_re_success = false;
  if (std::regex_search(first, last, m, index_reg)) {
    auto index_id = std::stoll(m.str(1), nullptr, 0);
    index = map_get_throw<ParserException>(output_index_map_, index_id);
    index_re_success = true;
    // add 1 for space between base and index operands
    expected_length += m.length() + 1;
  } else if (std::regex_search(first, last, m, index_phys)) {
    index = jit::codegen::PhyLocation::parse(m.str(1));
    index_re_success = true;
    expected_length += m.length() + 1;
  }

  // `m` still holds the successful index match here; group 3 is the
  // multiplier, if one was written.
  if (index_re_success && m.size() > 3 && m.str(3).size() > 0) {
    int64_t exp_multiplier = std::stoll(m.str(3), nullptr, 0);
    expect(
        exp_multiplier != 0 && (exp_multiplier & (exp_multiplier - 1)) == 0,
        code,
        "The multiplier should not be zero and must be integral power of 2.");
    // A power of two has exactly one set bit, so the count of trailing zeros
    // is its base-2 logarithm.
    multiplier = __builtin_ctzll(exp_multiplier);
  }

  // parse offset
  if (std::regex_search(first, last, m, offset_re)) {
    // need to remove space between sign and hex for stoll conversion
    offset = std::stoll(m.str(1) + m.str(2), nullptr, 0);
    // add 1 for the space that precedes the sign
    expected_length += m.length() + 1;
  }

  // Every character of the token must be accounted for by the pieces matched
  // above; anything else means the operand contains text that was not parsed.
  expect(
      expected_length == token.length(),
      code,
      "Unable to parse memory indirect operand.");

  opnd->setMemoryIndirect(base, index, multiplier, offset);
}
void Parser::fixOperands() {
for (auto& pair : basic_block_refs_) {
auto operand = pair.first;
int block_index = pair.second;
operand->setBasicBlock(
map_get_throw<ParserException>(block_index_map_, block_index));
}
for (auto& pair : instr_refs_) {
auto operand = pair.first;
int instr_index = pair.second;
auto instr = map_get_throw<ParserException>(output_index_map_, instr_index);
instr->output()->addUse(operand);
}
}
void Parser::connectBasicBlocks() {
// Note - Order of successors matters.
// It depends on the order in which we add pairs to basic_block_succs_
for (auto& succ_pair : basic_block_succs_) {
BasicBlock* source_block = succ_pair.first;
int dest_block_id = succ_pair.second;
source_block->addSuccessor(
map_get_throw<ParserException>(block_index_map_, dest_block_id));
}
}
// Gives an id to every instruction that still has the placeholder id -1.
//
// The function's next id is first set to one past the largest id used by any
// basic block or instruction, so the newly allocated ids cannot collide with
// ids that appeared in the parsed text.
void Parser::fixUnknownIds() {
  // Pass 1: find the largest id in use.
  int max_seen_id = -1;
  const auto track = [&max_seen_id](auto candidate) {
    if (candidate > max_seen_id) {
      max_seen_id = candidate;
    }
  };
  for (auto& block : func_->basicblocks()) {
    track(block->id());
    for (auto& instruction : block->instructions()) {
      track(instruction->id());
    }
  }
  func_->setNextId(max_seen_id + 1);

  // Pass 2: all basic blocks should already have an id, so only instructions
  // need one.
  for (auto& block : func_->basicblocks()) {
    for (auto& instruction : block->instructions()) {
      if (instruction->id() != -1) {
        continue;
      }
      instruction->setId(func_->allocateId());
    }
  }
}
} // namespace lir
} // namespace jit