def window()

in src/spark-project/sparder/src/main/scala/org/apache/kylin/query/runtime/plan/WindowPlan.scala [58:245]
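
Translates a Calcite OlapWindowRel into a Spark LogicalPlan: for every window group and aggregate call it builds a Spark WindowSpec (partitioning, ordering, frame bounds) plus the matching Spark SQL window function, and projects the resulting expressions alongside the child plan's columns.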


  def window(plan: LogicalPlan,
             rel: OlapWindowRel,
             datacontex: DataContext): LogicalPlan = {
    val start = System.currentTimeMillis()

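    // Counter used to build a unique alias for each generated window expression.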
    var windowCount = 0
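    // Note: the result of the next call is discarded; it appears to have no effect here.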
    rel.groups.asScala.head.upperBound
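    // Output columns of the child plan. Window constants are addressed by input-ref
    // indices past the last input column, hence the columnSize offset in constantMap.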
    val columnSize = plan.output.size

    val columns = plan.output.map(c => col(c.name))
    val constantMap = rel.getConstants.asScala
      .map(_.getValue)
      .zipWithIndex
      .map { entry =>
        (entry._2 + columnSize, entry._1)
      }.toMap[Int, Any]
    val visitor = new SparderRexVisitor(plan,
      rel.getInput.getRowType,
      datacontex)
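    // Translate each constant RexNode into a Spark literal column, so that
    // columnsAndConstants below can be indexed directly by RexInputRef indices.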
    val constants = rel.getConstants.asScala
      .map { constant =>
        k_lit(Literal.apply(constant.accept(visitor)))
      }
    val columnsAndConstants = columns ++ constants
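    // One Spark window expression is generated per aggregate call of every window group.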
    val windows = rel.groups.asScala
      .flatMap { group =>
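        // Flag whether any input field has a date/time-family SQL type; the flag is
        // passed to buildRange below when the frame bounds are computed.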
        var isDateTimeFamilyType = false
        val fieldsNameToType = rel.getInput.getRowType.getFieldList.asScala.zipWithIndex
          .map {
            case (field, index) => index -> field.getType.getSqlTypeName.toString
          }.toMap

        fieldsNameToType.foreach(map =>
          if (SparderTypeUtil.isDateTimeFamilyType(map._2)) {
            isDateTimeFamilyType = true
          })

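        // Sort keys of the window. For RANGE frames, timestamp keys are cast to Long
        // so that numeric frame offsets can be applied to them.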
        val orderByColumns = group.orderKeys
          .asInstanceOf[RelCollationImpl]
          .getFieldCollations
          .asScala
          .map { fieldIndex =>
            var column = columns.apply(fieldIndex.getFieldIndex)
            if (!group.isRows && fieldsNameToType(fieldIndex.getFieldIndex).equalsIgnoreCase("timestamp")) {
              column = column.cast(LongType)
            }
            fieldIndex.direction match {
              case Direction.DESCENDING =>
                column = column.desc
              case Direction.STRICTLY_DESCENDING =>
                column = column.desc
              case Direction.ASCENDING =>
                column = column.asc
              case Direction.STRICTLY_ASCENDING =>
                column = column.asc
              case _ =>
            }
            column
          }
          .toList
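        // PARTITION BY keys of the window group.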
        val partitionColumns = group.keys.asScala
          .map(fieldIndex => columns.apply(fieldIndex))
          .toSeq
        group.aggCalls.asScala.map { agg =>
          var windowDesc: WindowSpec = null
          val opName = agg.op.getName.toUpperCase(Locale.ROOT)
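          // Numeric constants (frame offsets such as "n PRECEDING") feed buildRange,
          // which is not shown in this excerpt, to derive the frame bounds.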
          val numberConstants = constantMap
            .filter(_._2.isInstanceOf[Number])
            .map { entry =>
              (entry._1, entry._2.asInstanceOf[Number])
            }.toMap
          val (lowerBound: Long, upperBound: Long) = buildRange(group, numberConstants, isDateTimeFamilyType, group.isRows)
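          // nonRangeSpecified / rowSpecified / sortSpecified are sets of op names defined
          // outside this excerpt; judging by their use here, they mark functions that must
          // not get an explicit frame, that always need a ROWS frame, and that need an
          // ORDER BY even when no sort key is given.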
          if (orderByColumns.nonEmpty) {

            windowDesc = Window.orderBy(orderByColumns: _*)
            if (!nonRangeSpecified.contains(opName)) {
              if (group.isRows || rowSpecified.contains(opName)) {
                windowDesc = windowDesc.rowsBetween(lowerBound, upperBound)
              } else {
                windowDesc = windowDesc.rangeBetween(lowerBound, upperBound)
              }
            }
          } else {
            if (sortSpecified.contains(opName)) {
              windowDesc = Window.orderBy(k_lit(1))
              if (!nonRangeSpecified.contains(opName)) {
                if (group.isRows || rowSpecified.contains(opName)) {
                  windowDesc = windowDesc.rowsBetween(lowerBound, upperBound)
                } else {
                  windowDesc = windowDesc.rangeBetween(lowerBound, upperBound)
                }
              }
            }
          }
          if (partitionColumns.nonEmpty) {
            windowDesc =
              if (windowDesc == null) Window.partitionBy(partitionColumns: _*)
              else windowDesc.partitionBy(partitionColumns: _*)
          }

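          // Map the Calcite window function to the corresponding Spark SQL function.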
          val func = opName match {
            case "ROW_NUMBER" =>
              row_number()
            case "RANK" =>
              rank()
            case "DENSE_RANK" =>
              dense_rank()
            case "FIRST_VALUE" =>
              first(
                columnsAndConstants.apply(
                  agg.operands.asScala.head.asInstanceOf[RexInputRef].getIndex))
            case "LAST_VALUE" =>
              last(
                columnsAndConstants.apply(
                  agg.operands.asScala.head.asInstanceOf[RexInputRef].getIndex))
            case "LEAD" =>
              val args =
                agg.operands.asScala.map(_.asInstanceOf[RexInputRef].getIndex)
              args.size match {
                // offset default value is 1 in spark
                case 1 => lead(columnsAndConstants.apply(args.head), 1)
                case 2 => lead(columnsAndConstants.apply(args.head),
                  constantMap.apply(args(1)).asInstanceOf[Number].intValue())
                case 3 =>
                  lead(columnsAndConstants.apply(args.head),
                    constantMap.apply(args(1)).asInstanceOf[Number].intValue(),
                    constantValue(rel, constantMap, args(2), visitor))
              }

            case "LAG" =>
              val args =
                agg.operands.asScala.map(_.asInstanceOf[RexInputRef].getIndex)
              args.size match {
                // offset default value is 1 in spark
                case 1 => lag(columnsAndConstants.apply(args.head), 1)
                case 2 => lag(columnsAndConstants.apply(args.head),
                  constantMap.apply(args(1)).asInstanceOf[Number].intValue())
                case 3 =>
                  lag(columnsAndConstants.apply(args.head),
                    constantMap.apply(args(1)).asInstanceOf[Number].intValue(),
                    constantValue(rel, constantMap, args(2), visitor))
              }
            case "NTILE" =>
              ntile(constantMap
                .apply(
                  agg.operands.asScala.head.asInstanceOf[RexInputRef].getIndex)
                .asInstanceOf[Number].intValue())
            case "COUNT" =>
              count(
                if (agg.operands.isEmpty) {
                  k_lit(1)
                } else {
                  columnsAndConstants.apply(
                    agg.operands.asScala.head.asInstanceOf[RexInputRef].getIndex)
                }
              )
            case "MAX" =>
              max(
                columnsAndConstants.apply(
                  agg.operands.asScala.head.asInstanceOf[RexInputRef].getIndex))
            case _ if opName.contains("SUM") => // also covers SUM variants such as Calcite's $SUM0
              sum(
                columnsAndConstants.apply(
                  agg.operands.asScala.head.asInstanceOf[RexInputRef].getIndex))
            case "MIN" =>
              min(
                columnsAndConstants.apply(
                  agg.operands.asScala.head.asInstanceOf[RexInputRef].getIndex))
            case "AVG" =>
              avg(
                columnsAndConstants.apply(
                  agg.operands.asScala.head.asInstanceOf[RexInputRef].getIndex))
          }
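          // Give every window expression a unique alias of the form
          // <identityHashCode(rel)>_window_<counter>.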
          windowCount = windowCount + 1
          val alias = s"${System.identityHashCode(rel)}_window_$windowCount"
          if (windowDesc == null) {
            func.over().alias(alias)
          } else {
            func.over(windowDesc).alias(alias)
          }
        }

      }
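    // Project the original columns plus the generated window expressions over the child plan.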
    val selectColumn = columns ++ windows
    val windowPlan = SparkOperation.project(selectColumn, plan)
    logInfo(s"Gen window cost time: ${System.currentTimeMillis() - start} ms")
    windowPlan
  }
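
For reference, here is a minimal, hypothetical sketch (not part of WindowPlan.scala; the column names and alias value are made up) of the expression shape that window() builds for a query such as ROW_NUMBER() OVER (PARTITION BY dept ORDER BY salary):

  import org.apache.spark.sql.expressions.Window
  import org.apache.spark.sql.functions.{col, row_number}

  // Partition + order spec, mirroring the Window.partitionBy(...).orderBy(...) calls above.
  val spec = Window.partitionBy(col("dept")).orderBy(col("salary").asc)
  // Generated expression; the alias follows <identityHashCode(rel)>_window_<counter>.
  val expr = row_number().over(spec).alias("12345_window_1")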