private def getAggRelWithRowConstruct()

in backends-velox/src/main/scala/io/glutenproject/execution/HashAggregateExecTransformer.scala [341:495]


  /**
   * Builds an aggregate rel whose input is a projection that row-constructs Velox intermediate
   * aggregation data.
   *
   * The projection emits, in order, the grouping expressions, followed by the input column(s) of
   * every aggregate expression:
   *   - in Partial mode, the function's raw input expressions (`aggregateFunction.children`);
   *   - for functions matched by `VeloxIntermediateData.Type` in PartialMerge/Final mode, a single
   *     row-construct column built from the Spark buffer attributes after reordering them (and
   *     casting where the types differ) to Velox's intermediate layout;
   *   - otherwise, the function's `inputAggBufferAttributes`.
   * The aggregate rel then refers to those projected columns by position.
   *
   * @param context                 Substrait context whose registered functions are used while
   *                                transforming expressions.
   * @param originalInputAttributes attributes of the input relation; expressions are bound
   *                                against them.
   * @param operatorId              operator id attached to both the project and the aggregate rel.
   * @param inputRel                relation the projection reads from.
   * @param validation              when true, the input types are attached to the project rel as
   *                                an advanced extension so the plan can be validated.
   * @return the aggregate rel built on top of the row-construct projection.
   * @throws UnsupportedOperationException if a non-Partial HyperLogLogPlusPlus does not have
   *         exactly one buffer attribute, a Velox-intermediate function appears in a mode other
   *         than Partial/PartialMerge/Final, a Velox intermediate column has no matching Spark
   *         buffer attribute, an aggregate expression has a filter, or an aggregate mode other
   *         than Partial/PartialMerge/Final is encountered.
   */
  private def getAggRelWithRowConstruct(
      context: SubstraitContext,
      originalInputAttributes: Seq[Attribute],
      operatorId: Long,
      inputRel: RelNode,
      validation: Boolean): RelNode = {
    val args = context.registeredFunction
    // Create a projection for row construct: grouping expressions first, then the input
    // column(s) of each aggregate expression.
    val exprNodes = new JArrayList[ExpressionNode]()
    groupingExpressions.foreach(
      expr => {
        exprNodes.add(
          ExpressionConverter
            .replaceWithExpressionTransformer(expr, originalInputAttributes)
            .doTransform(args))
      })

    for (aggregateExpression <- aggregateExpressions) {
      val aggFunc = aggregateExpression.aggregateFunction
      val functionInputAttributes = aggFunc.inputAggBufferAttributes
      aggFunc match {
        case _ if aggregateExpression.mode == Partial =>
          // Partial mode projects the function's raw input expressions (`children`), one column
          // each. The branches below handle the other modes, which project the intermediate
          // buffer attributes (`inputAggBufferAttributes`) instead.
          val childNodes = aggFunc.children
            .map(
              ExpressionConverter
                .replaceWithExpressionTransformer(_, originalInputAttributes)
                .doTransform(args)
            )
            .asJava
          exprNodes.addAll(childNodes)

        case _: HyperLogLogPlusPlus if aggFunc.aggBufferAttributes.size != 1 =>
          throw new UnsupportedOperationException("Only one input attribute is expected.")

        case VeloxIntermediateData.Type(veloxTypes: Seq[DataType]) =>
          // The process of handling the inconsistency in column types and order between
          // Spark and Velox is exactly the opposite of applyExtractStruct.
          aggregateExpression.mode match {
            case PartialMerge | Final =>
              val newInputAttributes = new ArrayBuffer[Attribute]()
              val childNodes = new JArrayList[ExpressionNode]()
              val (sparkOrders, sparkTypes) =
                aggFunc.aggBufferAttributes.map(attr => (attr.name, attr.dataType)).unzip
              val veloxOrders = VeloxIntermediateData.veloxIntermediateDataOrder(aggFunc)
              // For each Velox intermediate column, its position among Spark's buffer
              // attributes. `indexOf` yields -1 for an unknown name, which would otherwise only
              // surface as an opaque IndexOutOfBoundsException when indexing below.
              val adjustedOrders = veloxOrders.map {
                name =>
                  val sparkIdx = sparkOrders.indexOf(name)
                  if (sparkIdx < 0) {
                    throw new UnsupportedOperationException(
                      s"Velox intermediate column $name of $aggFunc has no matching Spark " +
                        s"aggregation buffer attribute among ${sparkOrders.mkString(", ")}.")
                  }
                  sparkIdx
              }
              veloxTypes.zipWithIndex.foreach {
                case (veloxType, idx) =>
                  val sparkType = sparkTypes(adjustedOrders(idx))
                  val attr = functionInputAttributes(adjustedOrders(idx))
                  val aggFuncInputAttrNode = ExpressionConverter
                    .replaceWithExpressionTransformer(attr, originalInputAttributes)
                    .doTransform(args)
                  // Cast only when Spark's buffer type differs from what Velox expects.
                  val expressionNode = if (sparkType != veloxType) {
                    newInputAttributes +=
                      attr.copy(dataType = veloxType)(attr.exprId, attr.qualifier)
                    ExpressionBuilder.makeCast(
                      ConverterUtils.getTypeNode(veloxType, attr.nullable),
                      aggFuncInputAttrNode,
                      SQLConf.get.ansiEnabled)
                  } else {
                    newInputAttributes += attr
                    aggFuncInputAttrNode
                  }
                  childNodes.add(expressionNode)
              }
              // All intermediate columns of this function become one row-construct column.
              exprNodes.add(getRowConstructNode(args, childNodes, newInputAttributes, aggFunc))
            case other =>
              throw new UnsupportedOperationException(s"$other is not supported.")
          }

        case _ =>
          val childNodes = functionInputAttributes
            .map(
              ExpressionConverter
                .replaceWithExpressionTransformer(_, originalInputAttributes)
                .doTransform(args)
            )
            .asJava
          exprNodes.addAll(childNodes)
      }
    }

    // Create a project rel. Emit starts after the original input columns (presumably so that
    // only the projected expressions are output).
    val emitStartIndex = originalInputAttributes.size
    val projectRel = if (!validation) {
      RelBuilder.makeProjectRel(inputRel, exprNodes, context, operatorId, emitStartIndex)
    } else {
      // Use an extension node to send the input types through Substrait plan for validation.
      val inputTypeNodeList = originalInputAttributes
        .map(attr => ConverterUtils.getTypeNode(attr.dataType, attr.nullable))
        .asJava
      val extensionNode = ExtensionBuilder.makeAdvancedExtension(
        BackendsApiManager.getTransformerApiInstance.packPBMessage(
          TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf))
      RelBuilder.makeProjectRel(
        inputRel,
        exprNodes,
        extensionNode,
        context,
        operatorId,
        emitStartIndex)
    }

    // Create aggregation rel. It selects the projected columns by position: the grouping keys
    // first, then the column(s) of each aggregate function in projection order.
    val groupingList = new JArrayList[ExpressionNode]()
    groupingExpressions.indices.foreach {
      idx => groupingList.add(ExpressionBuilder.makeSelection(idx))
    }
    var colIdx = groupingExpressions.size

    val aggFilterList = new JArrayList[ExpressionNode]()
    val aggregateFunctionList = new JArrayList[AggregateFunctionNode]()
    aggregateExpressions.foreach(
      aggExpr => {
        if (aggExpr.filter.isDefined) {
          throw new UnsupportedOperationException("Filter in final aggregation is not supported.")
        } else {
          // The number of filters must be aligned with that of aggregate functions; a null
          // entry is added for a function without a filter.
          aggFilterList.add(null)
        }

        val aggFunc = aggExpr.aggregateFunction
        val childrenNodes = new JArrayList[ExpressionNode]()
        aggExpr.mode match {
          case PartialMerge | Final =>
            // Reads a single column, because the projection above combined the intermediate
            // results into one column.
            // NOTE(review): a function reaching the projection's generic branch with more than
            // one `inputAggBufferAttributes` would have emitted several columns there, shifting
            // every later selection — confirm that cannot happen on this path.
            childrenNodes.add(ExpressionBuilder.makeSelection(colIdx))
            colIdx += 1
          case Partial =>
            // One column per raw input expression, matching the Partial projection branch.
            aggFunc.children.foreach {
              _ =>
                childrenNodes.add(ExpressionBuilder.makeSelection(colIdx))
                colIdx += 1
            }
          case _ =>
            throw new UnsupportedOperationException(
              s"$aggFunc of ${aggExpr.mode.toString} is not supported.")
        }
        addFunctionNode(args, aggFunc, childrenNodes, aggExpr.mode, aggregateFunctionList)
      })

    val extensionNode = getAdvancedExtension()
    RelBuilder.makeAggregateRel(
      projectRel,
      groupingList,
      aggregateFunctionList,
      aggFilterList,
      extensionNode,
      context,
      operatorId)
  }