in lib/bef_converter/mlir_to_bef/mlir_to_bef.cc [300:486]
// Walks `module` and records every entity that must be emitted into the BEF
// file: result types, kernels, native functions, BEF functions (including the
// regions attached to kernels), attribute values and, optionally, attribute
// names and attribute types.
//
// Operations inside compiled modules are skipped; symbol references into the
// compiled modules are passed to kernels as compilation unit attributes. Return
// ops are not recorded, but each one must be the last op of its block.
//
// Parameters:
//   module: the module to scan.
//   collect_attribute_types_and_names: when true, the names (AddString) and
//     types (AddAttributeType) of non-function attributes are also recorded.
//
// Returns LogicalResult::Failure if any check fails (an error is emitted on the
// offending operation / location), LogicalResult::Success otherwise. A failure
// on one op does not stop the walk, so several diagnostics may be reported.
LogicalResult EntityTable::Collect(mlir::ModuleOp module,
                                   bool collect_attribute_types_and_names) {
  auto result = LogicalResult::Success;
  // Function references found in attributes. They are resolved only after the
  // walk, once every function in the module has been registered.
  std::vector<std::pair<mlir::SymbolRefAttr, mlir::Location>> fn_attrs;
  mlir::Operation* const module_op = module.getOperation();

  module.walk([&](mlir::Operation* op) {
    // Ignore the module itself, and a few specific other ops.
    if (op == module_op) return;
    // Ignore operations inside compiled modules. Symbol references into the
    // compiled modules are passed to kernels as a compilation unit attribute.
    if (BefCompilationUnits::IsInCompiledModule(op)) return;

    // The return op gets special handling; ensure it is at the end of its
    // enclosing block, then skip it.
    if (IsReturn(op)) {
      if (&op->getBlock()->back() != op) {
        op->emitError() << "return op must be at the end of its block";
        result = LogicalResult::Failure;
      }
      return;
    }

    auto* cur_region = op->getParentRegion();
    // Notice the result types of the op.
    for (auto op_result : op->getResults()) AddType(op_result.getType());
    for (auto operand : op->getOperands()) {
      // Verify that the operand is defined inside the current region. We
      // don't support references to outer regions.
      if (operand.getParentRegion() != cur_region) {
        op->emitError()
            << "BEF executor only supports references to kernels within"
            << " the current region";
        result = LogicalResult::Failure;
        return;
      }
    }

    // We treat functions specially, putting them into the symbol table and
    // ignoring their attributes.
    if (auto fn = llvm::dyn_cast<mlir::FuncOp>(op)) {
      if (IsNativeFunc(fn)) {
        AddNativeFunction(fn);
        return;
      }
      if (fn.isExternal()) {
        fn.emitError() << "external functions are not allowed";
        result = LogicalResult::Failure;
        return;
      }
      // Verify that all functions end with a return to catch a common error.
      auto& last_op = fn.front().back();
      if (!IsReturn(&last_op)) {
        last_op.emitError() << "all functions need to have a tfrt.return";
        result = LogicalResult::Failure;
        return;
      }
      const bool is_sync_func = IsSyncFunc(fn);
      if (is_sync_func) {
        // Sync functions may neither return a block argument nor return the
        // same value twice.
        llvm::SmallSetVector<mlir::Value, 4> return_operands;
        for (auto iter : llvm::enumerate(last_op.getOperands())) {
          auto index = iter.index();
          const auto& operand = iter.value();
          if (operand.isa<mlir::BlockArgument>()) {
            last_op.emitError() << "return value " << index
                                << " is an argument in a sync function";
            result = LogicalResult::Failure;
            return;
          }
          if (!return_operands.insert(operand)) {
            last_op.emitError() << "return value " << index
                                << " is duplicated in a sync function";
            result = LogicalResult::Failure;
            return;
          }
        }
      }
      auto func_kind = is_sync_func ? FunctionKind::kSyncBEFFunction
                                    : FunctionKind::kBEFFunction;
      if (AddFunction(&fn.getBody(), fn.getName(), func_kind) ==
          LogicalResult::Failure) {
        result = LogicalResult::Failure;
      }
      return;
    }

    // Every other op is a kernel.
    AddKernel(op);
    // Keep track of any attributes used by this op.
    for (auto attr : op->getAttrs()) {
      // Skip cost attribute which is not used in runtime execution.
      //
      // TODO(tfrt-devs): Use attribute interface instead of hardcoding here.
      if (attr.getName() == "_tfrt_cost") continue;
      // Check to make sure that this is a supported attribute, if not, reject
      // it. The diagnostic is only emitted when nothing has failed yet, but an
      // unsupported attribute must never fall through into the attribute
      // tables below, regardless of earlier failures.
      if (!BefAttrEmitter::IsSupportedAttribute(attr.getValue())) {
        if (result == LogicalResult::Success) {
          op->emitError() << "BEF files cannot encode the '"
                          << attr.getName().getValue() << "' attribute";
        }
        result = LogicalResult::Failure;
        return;
      }
      // Returns a symbol ref to an executable operation (function that needs
      // to be converted to BEF). If the referenced symbol is inside the
      // compiled module returns None. All compiled operations will be added to
      // the attributes section as compilation units.
      auto bef_function_ref = [&]() -> Optional<mlir::SymbolRefAttr> {
        auto sym_attr = attr.getValue().dyn_cast<mlir::SymbolRefAttr>();
        if (!sym_attr) return llvm::None;
        // Check if the referenced symbol is in the compiled module.
        auto* sym_op = mlir::SymbolTable::lookupSymbolIn(module_op, sym_attr);
        if (sym_op && BefCompilationUnits::IsInCompiledModule(sym_op))
          return llvm::None;
        return sym_attr;
      };
      if (auto fn_attr = bef_function_ref()) {
        // Keep track of function attributes specially so we can diagnose
        // them once all functions are known.
        fn_attrs.push_back({*fn_attr, op->getLoc()});
        continue;
      }
      if (collect_attribute_types_and_names) {
        // Add attribute names and types for attribute types section and
        // attribute names section. These will be ignored by executor.
        AddString(attr.getName());
        AddAttributeType(attr.getValue());
      }
      // Skip collecting array of function attributes.
      if (auto array_attr = attr.getValue().dyn_cast<mlir::ArrayAttr>()) {
        if (!array_attr.empty() &&
            array_attr.begin()->dyn_cast<mlir::SymbolRefAttr>()) {
          continue;
        }
      }
      // We ignore the name of attributes, they just get passed as arguments.
      attributes.insert(attr.getValue());
    }
    // Also add any regions used by this kernel as BEF functions.
    for (auto& region : op->getRegions()) {
      if (AddFunction(&region, "", FunctionKind::kBEFFunction) ==
          LogicalResult::Failure) {
        result = LogicalResult::Failure;
        return;
      }
    }
  });

  // If we're successful, check to make sure that all functions that should be
  // translated to BEF can be resolved.
  if (result != LogicalResult::Success) return result;
  for (const auto& attr_and_loc : fn_attrs) {
    if (GetFunctionNamed(attr_and_loc.first.getRootReference().getValue()) ==
        -1) {
      mlir::emitError(attr_and_loc.second)
          << "function " << attr_and_loc.first << " not defined";
      return LogicalResult::Failure;
    }
  }
  return result;
}