in src/spark-project/engine-spark/src/main/scala/org/apache/kylin/engine/spark/job/CuboidAggregator.scala [90:274]
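  /**
   * Builds the aggregated cuboid: groups `dataset` by the given dimension ids and
   * aggregates every measure. When the input already carries pre-aggregated measure
   * columns (named by measure id, i.e. a reusable parent layout), the `Reuse*`
   * aggregates merge them; otherwise the `Encode*` aggregates build them from the
   * raw source columns.
   */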
  def aggregate(dataset: DataFrame,
                dimensions: util.Set[Integer],
                measures: util.Map[Integer, Measure],
                columnIdFunc: TblColRef => String,
                isSparkSQL: Boolean = false): DataFrame = {
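    // No measures: the cuboid degenerates to the distinct dimension tuples.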
    if (measures.isEmpty) {
      return dataset
        .select(NSparkCubingUtil.getColumns(dimensions): _*)
        .dropDuplicates()
    }
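    // A field named after a measure id means the input is a parent layout whose
    // pre-aggregated measure values can be merged instead of recomputed.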
    val reuseLayout = dataset.schema.fieldNames
      .contains(measures.keySet().asScala.head.toString)
    var taggedColIndex: Int = -1
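    // Translate each measure into one Spark aggregate Column, aliased by its measure id.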
    val agg = measures.asScala.map { measureEntry =>
      val measure = measureEntry._2
      val function = measure.getFunction
      val parameters = function.getParameters.asScala.toList
      val columns = new mutable.ListBuffer[Column]
      val returnType = function.getReturnDataType
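      // Column-backed measures read either the reused measure column or the raw source
      // columns; constant-parameter measures aggregate a literal instead.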
      if (parameters.head.isColumnType) {
        if (reuseLayout) {
          columns.append(col(measureEntry._1.toString))
        } else {
          columns.appendAll(parameters.map(p => col(columnIdFunc.apply(p.getColRef))))
        }
      } else {
        if (reuseLayout) {
          columns.append(col(measureEntry._1.toString))
        } else {
          val par = parameters.head.getValue
          if (function.getExpression.equalsIgnoreCase("SUM")) {
            columns.append(lit(par).cast(SparderTypeUtil.toSparkType(returnType)))
          } else {
            columns.append(lit(par))
          }
        }
      }
      function.getExpression.toUpperCase(Locale.ROOT) match {
        case "MAX" =>
          max(columns.head).as(measureEntry._1.toString)
        case "MIN" =>
          min(columns.head).as(measureEntry._1.toString)
        case "SUM" =>
          sum(columns.head).as(measureEntry._1.toString)
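        // A pre-aggregated COUNT must be summed when reusing a layout, not counted again.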
case "COUNT" =>
if (reuseLayout) {
sum(columns.head).as(measureEntry._1.toString)
} else {
count(columns.head).as(measureEntry._1.toString)
}
case "COUNT_DISTINCT" =>
// add for test
if (isSparkSQL) {
countDistinct(columns.head).as(measureEntry._1.toString)
} else {
            var cdCol = columns.head
            val isBitmap = returnType.getName.equalsIgnoreCase(BitmapMeasureType.DATATYPE_BITMAP)
            val isHllc = returnType.getName.startsWith(HLLCMeasureType.DATATYPE_HLLC)
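            // A bitmap measure with a second parameter is the optimized intersect-count
            // case: the extra column tags each row, and the result is a per-tag map of
            // bitmaps that is exploded back into rows after aggregation.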
            if (isBitmap && parameters.size == 2) {
              require(measures.size() == 1, "Opt intersect count can only have one measure.")
              if (!reuseLayout) {
                taggedColIndex = columnIdFunc.apply(parameters.last.getColRef).toInt
                val tagCol = col(taggedColIndex.toString)
                val separator = KapConfig.getInstanceFromEnv.getIntersectCountSeparator
                cdCol = wrapEncodeColumn(columns.head)
                new Column(OptIntersectCount(cdCol.expr, split(tagCol, s"\\$separator").expr).toAggregateExpression())
                  .as(s"map_${measureEntry._1.toString}")
              } else {
                new Column(ReusePreciseCountDistinct(cdCol.expr).toAggregateExpression())
                  .as(measureEntry._1.toString)
              }
            } else {
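              // Ordinary count distinct: encode raw values into a bitmap or HLL sketch on
              // the first build, merge the serialized sketches when reusing a layout.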
              if (!reuseLayout) {
                if (isBitmap) {
                  cdCol = wrapEncodeColumn(columns.head)
                  new Column(EncodePreciseCountDistinct(cdCol.expr).toAggregateExpression())
                    .as(measureEntry._1.toString)
                } else if (columns.length > 1 && isHllc) {
                  cdCol = wrapMutilHllcColumn(columns: _*)
                  new Column(EncodeApproxCountDistinct(cdCol.expr, returnType.getPrecision).toAggregateExpression())
                    .as(measureEntry._1.toString)
                } else {
                  new Column(EncodeApproxCountDistinct(cdCol.expr, returnType.getPrecision).toAggregateExpression())
                    .as(measureEntry._1.toString)
                }
              } else {
                if (isBitmap) {
                  new Column(ReusePreciseCountDistinct(cdCol.expr).toAggregateExpression())
                    .as(measureEntry._1.toString)
                } else if (columns.length > 1 && isHllc) {
                  cdCol = wrapMutilHllcColumn(columns: _*)
                  new Column(ReuseApproxCountDistinct(cdCol.expr, returnType.getPrecision).toAggregateExpression())
                    .as(measureEntry._1.toString)
                } else {
                  new Column(ReuseApproxCountDistinct(cdCol.expr, returnType.getPrecision).toAggregateExpression())
                    .as(measureEntry._1.toString)
                }
              }
            }
          }
case "TOP_N" =>
          val measureCol = function.getParameters.get(0).getColRef.getColumnDesc
          val schema = StructType(parameters.map(_.getColRef.getColumnDesc).map { colDesc =>
            val dataType = toSparkType(colDesc.getType)
            if (colDesc == measureCol) {
              StructField(s"MEASURE_${colDesc.getName}", dataType)
            } else {
              StructField(s"DIMENSION_${colDesc.getName}", dataType)
            }
          })
          if (reuseLayout) {
            new Column(ReuseTopN(returnType.getPrecision, schema, columns.head.expr)
              .toAggregateExpression()).as(measureEntry._1.toString)
          } else {
            new Column(EncodeTopN(returnType.getPrecision, schema, columns.head.expr, columns.drop(1).map(_.expr))
              .toAggregateExpression()).as(measureEntry._1.toString)
          }
case "PERCENTILE_APPROX" =>
new Column(Percentile(columns.head.expr, returnType.getPrecision)
.toAggregateExpression()).as(measureEntry._1.toString)
case "COLLECT_SET" =>
          if (reuseLayout) {
            array_distinct(flatten(collect_set(columns.head))).as(measureEntry._1.toString)
          } else {
            collect_set(columns.head).as(measureEntry._1.toString)
          }
case "CORR" =>
          new Column(Literal(null, DoubleType)).as(measureEntry._1.toString)
case "SUM_LC" =>
val colDataType = function.getReturnDataType
val sparkDataType = toSparkType(colDataType)
if (reuseLayout) {
new Column(ReuseSumLC(columns.head.expr, sparkDataType).toAggregateExpression()).as(measureEntry._1.toString)
} else {
new Column(EncodeSumLC(columns.head.expr, columns.drop(1).head.expr, sparkDataType)
.toAggregateExpression()).as(measureEntry._1.toString)
}
      }
    }.toSeq
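    // For the optimized intersect count, the tag column acts as a pseudo dimension and is
    // dropped from the group-by; it is restored below when the per-tag map is exploded.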
    val dim = if (taggedColIndex != -1 && !reuseLayout) {
      val d = new util.HashSet[Integer](dimensions)
      d.remove(taggedColIndex)
      d
    } else {
      dimensions
    }
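    // Group by the (possibly reduced) dimension set; with no dimensions, aggregate globally.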
    val df: DataFrame = if (!dim.isEmpty) {
      dataset
        .groupBy(NSparkCubingUtil.getColumns(dim): _*)
        .agg(agg.head, agg.drop(1): _*)
    } else {
      dataset
        .agg(agg.head, agg.drop(1): _*)
    }
    // Avoid sum(decimal) widening the precision on every pass,
    // e.g. sum(decimal(19,4)) -> decimal(29,4), sum(sum(decimal(19,4))) -> decimal(38,4).
    if (reuseLayout) {
      val columns = NSparkCubingUtil.getColumns(dimensions) ++ measureColumns(dataset.schema, measures)
      df.select(columns: _*)
    } else {
      if (taggedColIndex != -1) {
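        // Explode the per-tag map produced by OptIntersectCount back into rows: one row
        // per tag value, restoring the tag column and the count-distinct measure column.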
        val icCol = df.schema.fieldNames.filter(_.contains("map")).head
        val fieldsWithoutIc = df.schema.fieldNames.filterNot(_ == icCol)
        val cdMeasureName = icCol.split("_").last
        val newSchema = fieldsWithoutIc :+ taggedColIndex.toString :+ cdMeasureName
        val exploded = fieldsWithoutIc.map(col) :+ explode(col(icCol))
        df.select(exploded: _*).toDF(newSchema: _*)
      } else {
        df
      }
    }
  }