CodegenASTNode ExprCodeGenerator::convertVeloxExpressionToCodegenAST()

in velox/experimental/codegen/code_generator/ExprCodeGenerator.cpp [54:263]


// Recursively translates a Velox typed expression tree into a codegen AST.
//
// Translation rules, in the order they are tried:
//  * FieldAccessTypedExpr -> InputRefExpr, indexed by the position of the
//    field name in `inputRowType`. Its children are not visited.
//  * ConstantTypedExpr    -> constant node, for BOOLEAN, TINYINT, SMALLINT,
//    INTEGER, BIGINT, REAL, DOUBLE and VARCHAR only.
//  * CastTypedExpr        -> UDFCallExpr, only when null-on-failure is off
//    and the target type is fixed width.
//  * CallTypedExpr        -> special forms (row_constructor, if, switch,
//    coalesce, is_null, not), the built-in arithmetic / logical nodes when
//    the corresponding flag is set, otherwise a registered UDF (looked up
//    first by name + argument kinds, then by name only).
//  * ConcatTypedExpr      -> MakeRowExpression.
//
// @param node                 Root of the Velox expression to convert.
// @param inputRowType         Row type used to resolve field names to column
//                             indices.
// @param inputRowNullability  Per-column "may be null" flags, indexed like
//                             `inputRowType`. When empty, nullability of
//                             input references is left untouched.
// @return The root of the equivalent codegen AST.
// @throws CodegenNotSupported when no rule above matches `node`.
// @throws std::out_of_range   (from `.at()`) when a fixed-arity form is
//                             given fewer inputs than it needs, or when
//                             `inputRowNullability` is non-empty but shorter
//                             than the referenced column index.
CodegenASTNode ExprCodeGenerator::convertVeloxExpressionToCodegenAST(
    const std::shared_ptr<const core::ITypedExpr>& node,
    const RowType& inputRowType,
    std::vector<bool> inputRowNullability) {
  // inputRefExpr in velox has one child referring to the input row that does
  // not correspond to any node in the codegen AST and does not carry any
  // useful information, so it is handled before the children are visited.
  if (auto inputRefExpr =
          std::dynamic_pointer_cast<const core::FieldAccessTypedExpr>(node)) {
    // Resolve the column position once; it is needed both for the node itself
    // and for the nullability lookup below.
    const auto columnIndex = inputRowType.getChildIdx(inputRefExpr->name());
    auto codegenNode = std::make_shared<codegen::InputRefExpr>(
        inputRefExpr->type(), inputRefExpr->name(), columnIndex);
    if (!inputRowNullability.empty()) {
      // TODO: extract this into a separate pass.
      // Set nullability of inputRefExpr from inputRowNullability.
      codegenNode->setMaybeNull(inputRowNullability.at(columnIndex));
    }
    return codegenNode;
  }

  // Convert the children subtrees to codegen AST, recording the type kind of
  // each child for the typed UDF lookup further down.
  const auto& children = node->inputs();
  std::vector<CodegenASTNode> codegenInputs;
  std::vector<velox::TypeKind> inputTypes;
  codegenInputs.reserve(children.size());
  inputTypes.reserve(children.size());

  for (const auto& input : children) {
    codegenInputs.push_back(convertVeloxExpressionToCodegenAST(
        input, inputRowType, inputRowNullability));
    inputTypes.push_back(input->type()->kind());
  }

  // Create codegen ASTNode and connect children subtrees
  if (auto constantExpr =
          std::dynamic_pointer_cast<const core::ConstantTypedExpr>(node)) {
    switch (constantExpr->type()->kind()) {
      case TypeKind::BOOLEAN:
        return toCodegenConstantExpr<TypeKind::BOOLEAN>(
            constantExpr->type(), constantExpr->value());
      case TypeKind::TINYINT:
        return toCodegenConstantExpr<TypeKind::TINYINT>(
            constantExpr->type(), constantExpr->value());
      case TypeKind::SMALLINT:
        return toCodegenConstantExpr<TypeKind::SMALLINT>(
            constantExpr->type(), constantExpr->value());
      case TypeKind::INTEGER:
        return toCodegenConstantExpr<TypeKind::INTEGER>(
            constantExpr->type(), constantExpr->value());
      case TypeKind::BIGINT:
        return toCodegenConstantExpr<TypeKind::BIGINT>(
            constantExpr->type(), constantExpr->value());
      case TypeKind::REAL:
        return toCodegenConstantExpr<TypeKind::REAL>(
            constantExpr->type(), constantExpr->value());
      case TypeKind::DOUBLE:
        return toCodegenConstantExpr<TypeKind::DOUBLE>(
            constantExpr->type(), constantExpr->value());
      case TypeKind::VARCHAR:
        return toCodegenConstantExpr<TypeKind::VARCHAR>(
            constantExpr->type(), constantExpr->value());
      default:
        // Other constant kinds are not supported; fall through to the
        // "unsupported" error at the end of the function.
        break;
    }
  }

  // Cast expression
  if (auto castExpr =
          std::dynamic_pointer_cast<const core::CastTypedExpr>(node)) {
    // TODO: support null on failure
    if (!castExpr->nullOnFailure() && castExpr->type()->isFixedWidth()) {
      auto udfInformation =
          getCastUDF(castExpr->type()->kind(), castIntByTruncate);
      return std::make_shared<codegen::UDFCallExpr>(
          node->type(), udfInformation, codegenInputs);
    }
  }

  if (auto callExpr =
          std::dynamic_pointer_cast<const core::CallTypedExpr>(node)) {
    const auto& functionName = callExpr->name();

    // Make row expression
    if (functionName == "row_constructor") {
      return std::make_shared<codegen::MakeRowExpression>(
          node->type(), codegenInputs);
    }

    // If expression: three inputs is if/then/else, two inputs is if/then.
    // Any other arity falls through to the UDF lookup below.
    if (functionName == "if") {
      if (codegenInputs.size() == 3) {
        return std::make_shared<codegen::IfExpression>(
            node->type(), codegenInputs[0], codegenInputs[1], codegenInputs[2]);
      }

      if (codegenInputs.size() == 2) {
        return std::make_shared<codegen::IfExpression>(
            node->type(), codegenInputs[0], codegenInputs[1]);
      }
    }

    // Switch expression
    if (functionName == "switch") {
      return std::make_shared<codegen::SwitchExpression>(
          node->type(), codegenInputs);
    }

    // Symbol arithmetic expressions
    if (useBuiltInArithmetic_) {
      if (functionName == "plus") {
        return std::make_shared<codegen::AddExpr>(
            node->type(), codegenInputs.at(0), codegenInputs.at(1));
      }

      if (functionName == "minus") {
        return std::make_shared<codegen::SubtractExpr>(
            node->type(), codegenInputs.at(0), codegenInputs.at(1));
      }

      if (functionName == "multiply") {
        return std::make_shared<codegen::MultiplyExpr>(
            node->type(), codegenInputs.at(0), codegenInputs.at(1));
      }
    }

    if (useBuiltInLogical_) {
      if (functionName == "gt") {
        return std::make_shared<codegen::GreaterThan>(
            node->type(), codegenInputs.at(0), codegenInputs.at(1));
      }

      if (functionName == "gte") {
        return std::make_shared<codegen::GreaterThanEquel>(
            node->type(), codegenInputs.at(0), codegenInputs.at(1));
      }

      if (functionName == "lt") {
        return std::make_shared<codegen::LessThan>(
            node->type(), codegenInputs.at(0), codegenInputs.at(1));
      }

      if (functionName == "lte") {
        return std::make_shared<codegen::LessThanEqual>(
            node->type(), codegenInputs.at(0), codegenInputs.at(1));
      }

      if (functionName == "eq") {
        return std::make_shared<codegen::Equal>(
            node->type(), codegenInputs.at(0), codegenInputs.at(1));
      }

      if (functionName == "neq") {
        return std::make_shared<codegen::NotEqual>(
            node->type(), codegenInputs.at(0), codegenInputs.at(1));
      }

      // The codegen conjunction nodes are binary. A call carrying more than
      // two operands is folded left-to-right into nested binary nodes so that
      // no operand is silently dropped. With exactly two operands the result
      // is the same single node as before; with fewer, `.at()` throws.
      if (functionName == "and") {
        CodegenASTNode conjunction = std::make_shared<codegen::LogicalAnd>(
            codegenInputs.at(0), codegenInputs.at(1));
        for (size_t i = 2; i < codegenInputs.size(); ++i) {
          conjunction = std::make_shared<codegen::LogicalAnd>(
              conjunction, codegenInputs[i]);
        }
        return conjunction;
      }

      if (functionName == "or") {
        CodegenASTNode disjunction = std::make_shared<codegen::LogicalOr>(
            codegenInputs.at(0), codegenInputs.at(1));
        for (size_t i = 2; i < codegenInputs.size(); ++i) {
          disjunction = std::make_shared<codegen::LogicalOr>(
              disjunction, codegenInputs[i]);
        }
        return disjunction;
      }
    }

    if (functionName == "coalesce") {
      return std::make_shared<codegen::CoalesceExpr>(
          node->type(), codegenInputs);
    }

    if (functionName == "is_null") {
      return std::make_shared<codegen::IsNullExpr>(
          node->type(), codegenInputs.at(0));
    }

    if (functionName == "not") {
      return std::make_shared<codegen::NotExpr>(
          node->type(), codegenInputs.at(0));
    }

    // UDF calls
    // First check if the name is a supported argument-typed udf
    auto udfInformation =
        udfManager_.getUDFInformationTypedArgs(functionName, inputTypes);
    if (udfInformation.has_value()) {
      return std::make_shared<codegen::UDFCallExpr>(
          node->type(), *udfInformation, codegenInputs);
    }

    // Check name only udfs
    udfInformation = udfManager_.getUDFInformationUnTypedArgs(functionName);
    if (udfInformation.has_value()) {
      return std::make_shared<codegen::UDFCallExpr>(
          node->type(), *udfInformation, codegenInputs);
    }
  }

  // Another form of "Make Row" expression
  if (std::dynamic_pointer_cast<const core::ConcatTypedExpr>(node)) {
    return std::make_shared<codegen::MakeRowExpression>(
        node->type(), codegenInputs);
  }

  // Translation not supported
  throw CodegenNotSupported(fmt::format(
      "unsupported conversion from typed expression {}\n", node->toString()));
}