in src/spark-project/sparder/src/main/scala/org/apache/kylin/query/runtime/plan/WindowPlan.scala [58:245]
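// Translates a Calcite OlapWindowRel into a Spark LogicalPlan: every aggregate
// call of every window group becomes a Spark window expression, and the result is
// projected on top of the input plan alongside the original columns.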
def window(plan: LogicalPlan,
rel: OlapWindowRel,
datacontex: DataContext): LogicalPlan = {
val start = System.currentTimeMillis()
var windowCount = 0
rel.groups.asScala.head.upperBound // result is discarded; the frame bounds used below come from buildRange
val columnSize = plan.output.size
val columns = plan.output.map(c => col(c.name))
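// Calcite refers to window constants with input-ref indexes that start right after
// the input columns, so each constant is keyed by its position plus columnSize.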
val constantMap = rel.getConstants.asScala
.map(_.getValue)
.zipWithIndex
.map { case (value, index) => (index + columnSize, value) }
.toMap[Int, Any]
val visitor = new SparderRexVisitor(plan,
rel.getInput.getRowType,
datacontex)
val constants = rel.getConstants.asScala
.map { constant =>
k_lit(Literal.apply(constant.accept(visitor)))
}
val columnsAndConstants = columns ++ constants
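// Each window group shares one PARTITION BY / ORDER BY / frame definition; it is
// expanded into one Spark window expression per aggregate call in the group.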
val windows = rel.groups.asScala
.flatMap { group =>
val fieldsNameToType = rel.getInput.getRowType.getFieldList.asScala.zipWithIndex
.map {
case (field, index) => index -> field.getType.getSqlTypeName.toString
}.toMap
val isDateTimeFamilyType =
fieldsNameToType.values.exists(t => SparderTypeUtil.isDateTimeFamilyType(t))
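// Build the ORDER BY columns; for a RANGE frame a timestamp ordering column is
// cast to Long so the numeric bounds computed by buildRange can be applied to it.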
val orderByColumns = group.orderKeys
.asInstanceOf[RelCollationImpl]
.getFieldCollations
.asScala
.map { fieldIndex =>
var column = columns.apply(fieldIndex.getFieldIndex)
if (!group.isRows && fieldsNameToType(fieldIndex.getFieldIndex).equalsIgnoreCase("timestamp")) {
column = column.cast(LongType)
}
fieldIndex.direction match {
case Direction.DESCENDING | Direction.STRICTLY_DESCENDING =>
column = column.desc
case Direction.ASCENDING | Direction.STRICTLY_ASCENDING =>
column = column.asc
case _ =>
}
column
}
.toList
val partitionColumns = group.keys.asScala
.map(fieldIndex => columns.apply(fieldIndex))
.toSeq
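// For every aggregate call, assemble a WindowSpec and apply the matching Spark
// window function over it.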
group.aggCalls.asScala.map { agg =>
var windowDesc: WindowSpec = null
val opName = agg.op.getName.toUpperCase(Locale.ROOT)
val numberConstants = constantMap
.collect { case (index, value: Number) => index -> value }
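// buildRange converts the group's frame definition into the Long bounds that
// Spark's rowsBetween/rangeBetween expect.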
val (lowerBound: Long, upperBound: Long) = buildRange(group, numberConstants, isDateTimeFamilyType, group.isRows)
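// A frame is attached only when the spec is ordered: operators in nonRangeSpecified
// never get an explicit frame, operators in rowSpecified always use a ROWS frame,
// and operators in sortSpecified get a constant ORDER BY when the query has none.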
def withFrame(spec: WindowSpec): WindowSpec =
if (nonRangeSpecified.contains(opName)) spec
else if (group.isRows || rowSpecified.contains(opName)) spec.rowsBetween(lowerBound, upperBound)
else spec.rangeBetween(lowerBound, upperBound)
if (orderByColumns.nonEmpty) {
windowDesc = withFrame(Window.orderBy(orderByColumns: _*))
} else if (sortSpecified.contains(opName)) {
windowDesc = withFrame(Window.orderBy(k_lit(1)))
}
if (partitionColumns.nonEmpty) {
windowDesc =
if (windowDesc == null) Window.partitionBy(partitionColumns: _*)
else windowDesc.partitionBy(partitionColumns: _*)
}
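// Map the Calcite window operator onto the corresponding Spark SQL function.
// Operand indexes are resolved against columnsAndConstants (input columns followed
// by the window constants); an unsupported operator fails here with a MatchError.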
val func = opName match {
case "ROW_NUMBER" =>
row_number()
case "RANK" =>
rank()
case "DENSE_RANK" =>
dense_rank()
case "FIRST_VALUE" =>
first(
columnsAndConstants.apply(
agg.operands.asScala.head.asInstanceOf[RexInputRef].getIndex))
case "LAST_VALUE" =>
last(
columnsAndConstants.apply(
agg.operands.asScala.head.asInstanceOf[RexInputRef].getIndex))
case "LEAD" =>
val args =
agg.operands.asScala.map(_.asInstanceOf[RexInputRef].getIndex)
args.size match {
// offset default value is 1 in spark
case 1 => lead(columnsAndConstants.apply(args.head), 1)
case 2 => lead(columnsAndConstants.apply(args.head),
constantMap.apply(args(1)).asInstanceOf[Number].intValue())
case 3 =>
lead(columnsAndConstants.apply(args.head),
constantMap.apply(args(1)).asInstanceOf[Number].intValue(),
constantValue(rel, constantMap, args(2), visitor))
}
case "LAG" =>
val args =
agg.operands.asScala.map(_.asInstanceOf[RexInputRef].getIndex)
args.size match {
// offset default value is 1 in spark
case 1 => lag(columnsAndConstants.apply(args.head), 1)
case 2 => lag(columnsAndConstants.apply(args.head),
constantMap.apply(args(1)).asInstanceOf[Number].intValue())
case 3 =>
lag(columnsAndConstants.apply(args.head),
constantMap.apply(args(1)).asInstanceOf[Number].intValue(),
constantValue(rel, constantMap, args(2), visitor))
}
case "NTILE" =>
ntile(constantMap
.apply(
agg.operands.asScala.head.asInstanceOf[RexInputRef].getIndex)
.asInstanceOf[Number].intValue())
case "COUNT" =>
count(
if (agg.operands.isEmpty) {
k_lit(1)
} else {
columnsAndConstants.apply(
agg.operands.asScala.head.asInstanceOf[RexInputRef].getIndex)
}
)
case "MAX" =>
max(
columnsAndConstants.apply(
agg.operands.asScala.head.asInstanceOf[RexInputRef].getIndex))
case _ if opName.contains("SUM") =>
sum(
columnsAndConstants.apply(
agg.operands.asScala.head.asInstanceOf[RexInputRef].getIndex))
case "MIN" =>
min(
columnsAndConstants.apply(
agg.operands.asScala.head.asInstanceOf[RexInputRef].getIndex))
case "AVG" =>
avg(
columnsAndConstants.apply(
agg.operands.asScala.head.asInstanceOf[RexInputRef].getIndex))
}
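// Alias each window expression uniquely (per rel instance and per call) so the
// parent operators can reference the generated column.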
windowCount += 1
val alias = s"${System.identityHashCode(rel)}_window_$windowCount"
if (windowDesc == null) {
func.over().alias(alias)
} else {
func.over(windowDesc).alias(alias)
}
}
}
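// Keep the original columns and append the generated window columns, then project
// them on top of the input plan.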
val selectColumn = columns ++ windows
val windowPlan = SparkOperation.project(selectColumn, plan)
logInfo(s"Gen window cost Time :${System.currentTimeMillis() - start} ")
windowPlan
}