private def getAggRelWithRowConstruct()

in backends-velox/src/main/scala/org/apache/gluten/execution/HashAggregateExecTransformer.scala [321:481]


  /**
   * Builds a projection followed by an aggregation rel, where the projection prepares the input
   * columns of every aggregate function.
   *
   * The projection emits, in order:
   *   - one column per grouping expression;
   *   - for each aggregate expression, either its raw input expressions (`Partial` / `Complete`),
   *     a single row-construct column combining the intermediate buffer columns
   *     (`PartialMerge` / `Final` of functions matched by `VeloxIntermediateData.Type`), or the
   *     rewritten aggregate buffer attributes (all remaining cases).
   *
   * The aggregation rel then refers to these projected columns purely by position, so both phases
   * must agree on how many columns each aggregate expression occupies.
   *
   * @param context
   *   the Substrait context used to transform expressions and build rels
   * @param originalInputAttributes
   *   the attributes of the input against which expressions are resolved
   * @param operatorId
   *   the operator id passed to the created project and aggregate rels
   * @param inputRel
   *   the input rel of the created projection
   * @param validation
   *   forwarded to `RelBuilder.makeProjectRel`
   * @return
   *   the aggregate rel whose input is the row-construct projection
   * @throws GlutenNotSupportException
   *   if an aggregate expression carries a filter, if a `HyperLogLogPlusPlus` buffer does not
   *   consist of exactly one attribute, or if an aggregate mode is not supported
   */
  private def getAggRelWithRowConstruct(
      context: SubstraitContext,
      originalInputAttributes: Seq[Attribute],
      operatorId: Long,
      inputRel: RelNode,
      validation: Boolean): RelNode = {
    // Create a projection for row construct.
    val exprNodes = new JArrayList[ExpressionNode]()
    groupingExpressions.foreach(
      expr => {
        exprNodes.add(
          ExpressionConverter
            .replaceWithExpressionTransformer(expr, originalInputAttributes)
            .doTransform(context))
      })

    for (aggregateExpression <- aggregateExpressions) {
      val aggFunc = aggregateExpression.aggregateFunction
      val functionInputAttributes = aggFunc.inputAggBufferAttributes
      aggFunc match {
        // Partial and Complete both consume the raw input expressions of the function rather than
        // intermediate buffer columns. This must stay in sync with the aggregation rel below,
        // which selects one projected column per child for these two modes.
        case _ if aggregateExpression.mode == Partial || aggregateExpression.mode == Complete =>
          val childNodes = aggFunc.children
            .map(
              ExpressionConverter
                .replaceWithExpressionTransformer(_, originalInputAttributes)
                .doTransform(context)
            )
            .asJava
          exprNodes.addAll(childNodes)

        case _: HyperLogLogPlusPlus if aggFunc.aggBufferAttributes.size != 1 =>
          throw new GlutenNotSupportException("Only one input attribute is expected.")

        case _ @VeloxIntermediateData.Type(veloxTypes: Seq[DataType]) =>
          val rewrittenInputAttributes =
            rewriteAggBufferAttributes(functionInputAttributes, originalInputAttributes)
          // The process of handling the inconsistency in column types and order between
          // Spark and Velox is exactly the opposite of applyExtractStruct.
          aggregateExpression.mode match {
            case PartialMerge | Final =>
              val newInputAttributes = new ArrayBuffer[Attribute]()
              val childNodes = new JArrayList[ExpressionNode]()
              val (sparkOrders, sparkTypes) =
                aggFunc.aggBufferAttributes.map(attr => (attr.name, attr.dataType)).unzip
              val veloxOrders = VeloxIntermediateData.veloxIntermediateDataOrder(aggFunc)
              // For every Velox buffer column, the index of the matching Spark buffer column,
              // or -1 when Spark has no such column.
              val adjustedOrders = veloxOrders.map(o => sparkOrders.indexOf(o.head))
              veloxTypes.zipWithIndex.foreach {
                case (veloxType, idx) =>
                  val adjustedIdx = adjustedOrders(idx)
                  if (adjustedIdx == -1) {
                    // The Velox aggregate intermediate buffer column not found in Spark.
                    // For example, skewness and kurtosis share the same aggregate buffer in Velox,
                    // and Kurtosis additionally requires the buffer column of m4, which is
                    // always 0 for skewness. In Spark, the aggregate buffer of skewness does not
                    // have the column of m4, thus a placeholder m4 with a value of 0 must be passed
                    // to Velox, and this value cannot be omitted. Velox will always read m4 column
                    // when accessing the intermediate data.
                    val extraAttr = AttributeReference(veloxOrders(idx).head, veloxType)()
                    newInputAttributes += extraAttr
                    val lt = Literal.default(veloxType)
                    childNodes.add(ExpressionBuilder.makeLiteral(lt.value, lt.dataType, false))
                  } else {
                    val sparkType = sparkTypes(adjustedIdx)
                    val attr = rewrittenInputAttributes(adjustedIdx)
                    val aggFuncInputAttrNode = ExpressionConverter
                      .replaceWithExpressionTransformer(attr, originalInputAttributes)
                      .doTransform(context)
                    // Cast the Spark buffer column when its type differs from the Velox one.
                    val expressionNode = if (sparkType != veloxType) {
                      newInputAttributes +=
                        attr.copy(dataType = veloxType)(attr.exprId, attr.qualifier)
                      ExpressionBuilder.makeCast(
                        ConverterUtils.getTypeNode(veloxType, attr.nullable),
                        aggFuncInputAttrNode,
                        SQLConf.get.ansiEnabled)
                    } else {
                      newInputAttributes += attr
                      aggFuncInputAttrNode
                    }
                    childNodes.add(expressionNode)
                  }
              }
              // All buffer columns of this function are combined into a single projected column.
              exprNodes.add(
                getRowConstructNode(context, childNodes, newInputAttributes.toSeq, aggFunc))
            case other =>
              // Partial and Complete are handled by the first branch, so this is only a guard
              // against unexpected modes.
              throw new GlutenNotSupportException(s"$other is not supported.")
          }

        case _ =>
          val rewrittenInputAttributes =
            rewriteAggBufferAttributes(functionInputAttributes, originalInputAttributes)
          val childNodes = rewrittenInputAttributes
            .map(
              ExpressionConverter
                .replaceWithExpressionTransformer(_, originalInputAttributes)
                .doTransform(context)
            )
            .asJava
          exprNodes.addAll(childNodes)
      }
    }

    // Create a project rel.
    val projectRel = RelBuilder.makeProjectRel(
      originalInputAttributes.asJava,
      inputRel,
      exprNodes,
      context,
      operatorId,
      validation)

    // Create aggregation rel.
    // colIdx walks over the projected columns: grouping columns first, then the columns of each
    // aggregate expression in declaration order.
    val groupingList = new JArrayList[ExpressionNode]()
    var colIdx = 0
    groupingExpressions.foreach {
      _ =>
        groupingList.add(ExpressionBuilder.makeSelection(colIdx))
        colIdx += 1
    }

    val aggFilterList = new JArrayList[ExpressionNode]()
    val aggregateFunctionList = new JArrayList[AggregateFunctionNode]()
    aggregateExpressions.foreach(
      aggExpr => {
        if (aggExpr.filter.isDefined) {
          throw new GlutenNotSupportException("Filter in final aggregation is not supported.")
        } else {
          // The number of filters should be aligned with that of aggregate functions.
          aggFilterList.add(null)
        }

        val aggFunc = aggExpr.aggregateFunction
        val childrenNodes = new JArrayList[ExpressionNode]()
        aggExpr.mode match {
          case PartialMerge | Final =>
            // Only occupies one column due to intermediate results are combined
            // by previous projection.
            childrenNodes.add(ExpressionBuilder.makeSelection(colIdx))
            colIdx += 1
          case Partial | Complete =>
            // One projected column per raw input expression, matching the projection above.
            aggFunc.children.foreach {
              _ =>
                childrenNodes.add(ExpressionBuilder.makeSelection(colIdx))
                colIdx += 1
            }
          case _ =>
            throw new GlutenNotSupportException(
              s"$aggFunc of ${aggExpr.mode.toString} is not supported.")
        }
        addFunctionNode(context, aggFunc, childrenNodes, aggExpr.mode, aggregateFunctionList)
      })

    val extensionNode = getAdvancedExtension()
    RelBuilder.makeAggregateRel(
      projectRel,
      groupingList,
      aggregateFunctionList,
      aggFilterList,
      extensionNode,
      context,
      operatorId)
  }