def aggregate()

in src/spark-project/engine-spark/src/main/scala/org/apache/kylin/engine/spark/job/CuboidAggregator.scala [90:274]


  def aggregate(dataset: DataFrame,
                dimensions: util.Set[Integer],
                measures: util.Map[Integer, Measure],
                columnIdFunc: TblColRef => String,
                isSparkSQL: Boolean = false): DataFrame = {
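    // No measures: the cuboid is just the distinct combinations of the dimensions.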
    if (measures.isEmpty) {
      return dataset
        .select(NSparkCubingUtil.getColumns(dimensions): _*)
        .dropDuplicates()
    }

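    // If measure ids already appear as columns, the input is an existing layout
    // whose partial aggregates can be merged instead of re-encoded from the flat table.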
    val reuseLayout = dataset.schema.fieldNames
      .contains(measures.keySet().asScala.head.toString)

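    // Column id of the tag column used by the optimized intersect count, if any;
    // it is dropped from the group-by keys and restored by an explode at the end.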
    var taggedColIndex: Int = -1

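    // Build one Spark aggregate Column per measure, aliased by its measure id.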
    val agg = measures.asScala.map { measureEntry =>
      val measure = measureEntry._2
      val function = measure.getFunction
      val parameters = function.getParameters.asScala.toList
      val columns = new mutable.ListBuffer[Column]
      val returnType = function.getReturnDataType
      if (parameters.head.isColumnType) {
        if (reuseLayout) {
          columns.append(col(measureEntry._1.toString))
        } else {
          columns.appendAll(parameters.map(p => col(columnIdFunc.apply(p.getColRef))))
        }
      } else {
        if (reuseLayout) {
          columns.append(col(measureEntry._1.toString))
        } else {
          val par = parameters.head.getValue
          if (function.getExpression.equalsIgnoreCase("SUM")) {
            columns.append(lit(par).cast(SparderTypeUtil.toSparkType(returnType)))
          } else {
            columns.append(lit(par))
          }
        }
      }

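      // Dispatch on the measure expression. On a reused layout the aggregate merges
      // partial results (Reuse*/sum); otherwise it encodes raw columns (Encode*).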
      function.getExpression.toUpperCase(Locale.ROOT) match {
        case "MAX" =>
          max(columns.head).as(measureEntry._1.toString)
        case "MIN" =>
          min(columns.head).as(measureEntry._1.toString)
        case "SUM" =>
          sum(columns.head).as(measureEntry._1.toString)
        case "COUNT" =>
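          // Merging a parent layout: partial counts must be summed, not re-counted.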
          if (reuseLayout) {
            sum(columns.head).as(measureEntry._1.toString)
          } else {
            count(columns.head).as(measureEntry._1.toString)
          }
        case "COUNT_DISTINCT" =>
            // isSparkSQL is only set by tests; fall back to Spark's built-in count distinct
          if (isSparkSQL) {
            countDistinct(columns.head).as(measureEntry._1.toString)
          } else {
            var cdCol = columns.head
            val isBitmap = returnType.getName.equalsIgnoreCase(BitmapMeasureType.DATATYPE_BITMAP)
            val isHllc = returnType.getName.startsWith(HLLCMeasureType.DATATYPE_HLLC)

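            // A bitmap measure with two parameters is the optimized intersect count:
            // the second parameter tags each row, and results are collected per tag
            // into a map that is exploded back into rows after aggregation.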
            if (isBitmap && parameters.size == 2) {
              require(measures.size() == 1, "Opt intersect count can only have one measure.")
              if (!reuseLayout) {
                taggedColIndex = columnIdFunc.apply(parameters.last.getColRef).toInt
                val tagCol = col(taggedColIndex.toString)
                val separator = KapConfig.getInstanceFromEnv.getIntersectCountSeparator
                cdCol = wrapEncodeColumn(columns.head)
                new Column(OptIntersectCount(cdCol.expr, split(tagCol, s"\\$separator").expr).toAggregateExpression())
                  .as(s"map_${measureEntry._1.toString}")
              } else {
                new Column(ReusePreciseCountDistinct(cdCol.expr).toAggregateExpression())
                  .as(measureEntry._1.toString)
              }
            } else {
              if (!reuseLayout) {
                if (isBitmap) {
                  cdCol = wrapEncodeColumn(columns.head)
                  new Column(EncodePreciseCountDistinct(cdCol.expr).toAggregateExpression())
                    .as(measureEntry._1.toString)
                } else if (columns.length > 1 && isHllc) {
                  cdCol = wrapMutilHllcColumn(columns: _*)
                  new Column(EncodeApproxCountDistinct(cdCol.expr, returnType.getPrecision).toAggregateExpression())
                    .as(measureEntry._1.toString)
                } else {
                  new Column(EncodeApproxCountDistinct(cdCol.expr, returnType.getPrecision).toAggregateExpression())
                    .as(measureEntry._1.toString)
                }
              } else {
                if (isBitmap) {
                  new Column(ReusePreciseCountDistinct(cdCol.expr).toAggregateExpression())
                    .as(measureEntry._1.toString)
                } else if (columns.length > 1 && isHllc) {
                  cdCol = wrapMutilHllcColumn(columns: _*)
                  new Column(ReuseApproxCountDistinct(cdCol.expr, returnType.getPrecision).toAggregateExpression())
                    .as(measureEntry._1.toString)
                } else {
                  new Column(ReuseApproxCountDistinct(cdCol.expr, returnType.getPrecision).toAggregateExpression())
                    .as(measureEntry._1.toString)
                }
              }
            }
          }
        case "TOP_N" =>
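          // TOP_N is kept as structs of (measure, dimensions...); build the struct
          // schema from the parameter columns, marking measure vs. dimension fields.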

          val measureColDesc = function.getParameters.get(0).getColRef.getColumnDesc

          val schema = StructType(parameters.map(_.getColRef.getColumnDesc).map { columnDesc =>
            val dataType = toSparkType(columnDesc.getType)
            if (columnDesc == measureColDesc) {
              StructField(s"MEASURE_${columnDesc.getName}", dataType)
            } else {
              StructField(s"DIMENSION_${columnDesc.getName}", dataType)
            }
          })

          if (reuseLayout) {
            new Column(ReuseTopN(returnType.getPrecision, schema, columns.head.expr)
              .toAggregateExpression()).as(measureEntry._1.toString)
          } else {
            new Column(EncodeTopN(returnType.getPrecision, schema, columns.head.expr, columns.drop(1).map(_.expr))
              .toAggregateExpression()).as(measureEntry._1.toString)
          }
        case "PERCENTILE_APPROX" =>
          new Column(Percentile(columns.head.expr, returnType.getPrecision)
            .toAggregateExpression()).as(measureEntry._1.toString)
        case "COLLECT_SET" =>
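          // On a reused layout each value is already an array, so merge by
          // collecting, flattening and de-duplicating.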
          if (reuseLayout) {
            array_distinct(flatten(collect_set(columns.head))).as(measureEntry._1.toString)
          } else {
            collect_set(columns.head).as(measureEntry._1.toString)
          }
        case "CORR" =>
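          // CORR is not computed during cubing; a typed null keeps the column in the schema.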
          new Column(Literal(null, DoubleType)).as(measureEntry._1.toString)
        case "SUM_LC" =>
          val colDataType = function.getReturnDataType
          val sparkDataType = toSparkType(colDataType)
          if (reuseLayout) {
            new Column(ReuseSumLC(columns.head.expr, sparkDataType).toAggregateExpression()).as(measureEntry._1.toString)
          } else {
            new Column(EncodeSumLC(columns.head.expr, columns.drop(1).head.expr, sparkDataType)
              .toAggregateExpression()).as(measureEntry._1.toString)
          }
      }
    }.toSeq

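    // The tag column feeds the optimized intersect-count aggregate itself,
    // so it must not also be a group-by key.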
    val dim = if (taggedColIndex != -1 && !reuseLayout) {
      val d = new util.HashSet[Integer](dimensions)
      d.remove(taggedColIndex)
      d
    } else {
      dimensions
    }

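    // Group by the remaining dimensions; with no dimensions left this is the
    // grand-total cuboid, so aggregate the whole dataset.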
    val df: DataFrame = if (!dim.isEmpty) {
      dataset
        .groupBy(NSparkCubingUtil.getColumns(dim): _*)
        .agg(agg.head, agg.drop(1): _*)
    } else {
      dataset
        .agg(agg.head, agg.drop(1): _*)
    }

    // Avoid sum(decimal) widening the precision on each pass.
    // For example: sum(decimal(19,4)) -> decimal(29,4), sum(sum(decimal(19,4))) -> decimal(38,4)
    if (reuseLayout) {
      val columns = NSparkCubingUtil.getColumns(dimensions) ++ measureColumns(dataset.schema, measures)
      df.select(columns: _*)
    } else {
      if (taggedColIndex != -1) {
        val icCol = df.schema.fieldNames.filter(_.startsWith("map_")).head
        val fieldsWithoutIc = df.schema.fieldNames.filterNot(_ == icCol)

        val cdMeasureName = icCol.split("_").last
        val newSchema = fieldsWithoutIc :+ taggedColIndex.toString :+ cdMeasureName

        val exploded = fieldsWithoutIc.map(col) :+ explode(col(icCol))
        df.select(exploded: _*).toDF(newSchema: _*)
      } else {
        df
      }
    }
  }
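
Two short sketches may help make the tail of this method concrete. First, the decimal
note above: Spark widens DecimalType precision by 10 on every sum, so summing a summed
layout would drift from decimal(19,4) to decimal(29,4) and then decimal(38,4). A minimal
sketch (the object name, column names, and local-mode session are invented for
illustration):

  import org.apache.spark.sql.SparkSession
  import org.apache.spark.sql.functions.{col, explode, sum}
  import org.apache.spark.sql.types.DecimalType

  object DecimalWideningSketch extends App {
    val spark = SparkSession.builder().master("local[1]").appName("sketch").getOrCreate()
    import spark.implicits._

    val df = Seq("1.2345").toDF("v").select(col("v").cast(DecimalType(19, 4)).as("v"))

    val once = df.agg(sum("v").as("s"))    // s: decimal(29,4)
    val twice = once.agg(sum("s").as("s")) // s: decimal(38,4), capped at 38
    once.printSchema()
    twice.printSchema()

    // Re-selecting through the declared measure type, as the reuseLayout branch
    // does via measureColumns, pins the schema across repeated merges.
    twice.select(col("s").cast(DecimalType(19, 4)).as("s")).printSchema()
  }

Second, the explode-and-rename step for the optimized intersect count: exploding a
map-typed column yields key/value columns, which toDF renames back to the tag column id
and the measure id. A toy version (the ids "0", "2" and "100000" are made up, and plain
Longs stand in for bitmaps):

  // assumes the same imports and SparkSession as the sketch above
  val grouped = Seq((1, Map("cn" -> 10L, "us" -> 20L))).toDF("0", "map_100000")
  val icCol = "map_100000"
  val rest = grouped.schema.fieldNames.filterNot(_ == icCol)
  grouped.select(rest.map(col) :+ explode(col(icCol)): _*)
    .toDF(rest :+ "2" :+ "100000": _*)
    .show() // rows: (1, cn, 10) and (1, us, 20)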