LogicalResult EntityTable::Collect()

in lib/bef_converter/mlir_to_bef/mlir_to_bef.cc [300:486]


// Walks `module` and records the entities the BEF emitter needs: result types,
// kernels, functions (BEF, sync BEF and native), region bodies, attributes and,
// optionally, attribute names and types.
//
// While walking, it validates structural constraints:
//   * return ops must be the last op of their block;
//   * operands must be defined in the same region as their user;
//   * non-native functions must not be external and must end with a return;
//   * sync functions must not return block arguments or duplicated values;
//   * attributes must be encodable by BefAttrEmitter;
//   * every function symbol referenced from an attribute must resolve.
//
// Errors are reported via MLIR diagnostics. The walk keeps visiting other ops
// after an error (each `return` below only leaves the per-op callback), and the
// function then returns LogicalResult::Failure.
//
// Args:
//   module: module to scan; ops inside compiled modules are skipped.
//   collect_attribute_types_and_names: when true, also records each attribute's
//     name and type (used for the attribute names/types sections).
LogicalResult EntityTable::Collect(mlir::ModuleOp module,
                                   bool collect_attribute_types_and_names) {
  auto result = LogicalResult::Success;
  auto* module_op = module.getOperation();

  // Function symbol references found in attributes, with the location of the
  // referencing op. They are resolved only after the whole module has been
  // walked, so references to functions that appear later can be found.
  std::vector<std::pair<mlir::SymbolRefAttr, mlir::Location>> fn_attrs;

  module.walk([&](mlir::Operation* op) {
    // Ignore the module itself.
    if (op == module_op) return;

    // Ignore operations inside compiled modules. Symbol references into the
    // compiled modules are passed to kernels as compilation unit attributes.
    if (BefCompilationUnits::IsInCompiledModule(op)) return;

    // The return op gets special handling; only verify that it terminates its
    // enclosing block.
    if (IsReturn(op)) {
      if (&op->getBlock()->back() != op) {
        op->emitError() << "return op must be at the end of its block";
        result = LogicalResult::Failure;
      }
      return;
    }

    auto* cur_region = op->getParentRegion();

    // Record the result types of the op.
    for (auto op_result : op->getResults()) AddType(op_result.getType());

    // Only references to values defined in the current region are supported;
    // references to outer regions are rejected.
    for (auto operand : op->getOperands()) {
      if (operand.getParentRegion() != cur_region) {
        op->emitError()
            << "BEF executor only supports references to kernels within"
            << " the current region";
        result = LogicalResult::Failure;
        return;
      }
    }

    // Functions go into the function table; their attributes are ignored.
    if (auto fn = llvm::dyn_cast<mlir::FuncOp>(op)) {
      if (IsNativeFunc(fn)) {
        AddNativeFunction(fn);
        return;
      }

      if (fn.isExternal()) {
        fn.emitError() << "external functions are not allowed";
        result = LogicalResult::Failure;
        return;
      }

      // Verify that all functions end with a return to catch a common error.
      auto& last_op = fn.front().back();
      if (!IsReturn(&last_op)) {
        last_op.emitError() << "all functions need to have a tfrt.return";
        result = LogicalResult::Failure;
        return;
      }

      const bool is_sync = IsSyncFunc(fn);
      if (is_sync) {
        // A sync function must return distinct values that are not arguments.
        llvm::SmallSetVector<mlir::Value, 4> return_operands;
        for (auto iter : llvm::enumerate(last_op.getOperands())) {
          auto index = iter.index();
          mlir::Value operand = iter.value();
          if (operand.isa<mlir::BlockArgument>()) {
            last_op.emitError() << "return value " << index
                                << " is an argument in a sync function";
            result = LogicalResult::Failure;
            return;
          }

          if (!return_operands.insert(operand)) {
            last_op.emitError() << "return value " << index
                                << " is duplicated in a sync function";
            result = LogicalResult::Failure;
            return;
          }
        }
      }

      auto func_kind = is_sync ? FunctionKind::kSyncBEFFunction
                               : FunctionKind::kBEFFunction;
      if (AddFunction(&fn.getBody(), fn.getName(), func_kind) ==
          LogicalResult::Failure) {
        result = LogicalResult::Failure;
      }
      return;
    }

    // Any other op is a kernel.
    AddKernel(op);

    // Keep track of any attributes used by this op.
    for (auto attr : op->getAttrs()) {
      // Skip the cost attribute, which is not used in runtime execution.
      //
      // TODO(tfrt-devs): Use attribute interface instead of hardcoding here.
      if (attr.getName() == "_tfrt_cost") continue;

      // Reject attributes BEF cannot encode. The diagnostic is emitted only if
      // no earlier error was recorded. The attribute must never fall through
      // to the collection logic below, even when an earlier error exists.
      if (!BefAttrEmitter::IsSupportedAttribute(attr.getValue())) {
        if (result == LogicalResult::Success) {
          op->emitError() << "BEF files cannot encode the '"
                          << attr.getName().getValue() << "' attribute";
        }
        result = LogicalResult::Failure;
        return;
      }

      // A symbol ref to an executable operation (a function that needs to be
      // converted to BEF) is tracked separately so it can be diagnosed later.
      // Refs that resolve into a compiled module are treated as regular
      // attributes instead (they become compilation units).
      if (auto sym_attr = attr.getValue().dyn_cast<mlir::SymbolRefAttr>()) {
        auto* sym_op = mlir::SymbolTable::lookupSymbolIn(module_op, sym_attr);
        if (!sym_op || !BefCompilationUnits::IsInCompiledModule(sym_op)) {
          fn_attrs.push_back({sym_attr, op->getLoc()});
          continue;
        }
      }

      if (collect_attribute_types_and_names) {
        // Add attribute names and types for the attribute types section and
        // attribute names section. These will be ignored by the executor.
        AddString(attr.getName());
        AddAttributeType(attr.getValue());
      }

      // Skip collecting arrays of function attributes (detected by the first
      // element being a symbol ref).
      if (auto array_attr = attr.getValue().dyn_cast<mlir::ArrayAttr>()) {
        if (!array_attr.empty() &&
            array_attr.begin()->dyn_cast<mlir::SymbolRefAttr>()) {
          continue;
        }
      }

      // Attribute names are ignored here; the values are passed as arguments.
      attributes.insert(attr.getValue());
    }

    // Also add any regions used by this op as BEF functions.
    for (auto& region : op->getRegions()) {
      if (AddFunction(&region, "", FunctionKind::kBEFFunction) ==
          LogicalResult::Failure) {
        result = LogicalResult::Failure;
        return;
      }
    }
  });

  if (result == LogicalResult::Failure) return result;

  // Make sure every function referenced from an attribute, which must be
  // translated to BEF, can be resolved.
  for (const auto& attr_and_loc : fn_attrs) {
    if (GetFunctionNamed(attr_and_loc.first.getRootReference().getValue()) ==
        -1) {
      mlir::emitError(attr_and_loc.second)
          << "function " << attr_and_loc.first << " not defined";
      return LogicalResult::Failure;
    }
  }

  return result;
}