def sortDataFrameBySample()

in hudi-client/hudi-spark-client/src/main/scala/org/apache/spark/sql/hudi/execution/RangeSample.scala [256:449]


  def sortDataFrameBySample(df: DataFrame, layoutOptStrategy: LayoutOptimizationStrategy, orderByCols: Seq[String], targetPartitionsCount: Int): DataFrame = {
    val spark = df.sparkSession
    val columnsMap = df.schema.fields.map(item => (item.name, item)).toMap
    val fieldNum = df.schema.fields.length
    val checkCols = orderByCols.filter(col => columnsMap(col) != null)

    if (orderByCols.isEmpty || checkCols.isEmpty) {
      df
    } else {
      val zFields = orderByCols.map { col =>
        val newCol = columnsMap(col)
        if (newCol == null) {
          (-1, null)
        } else {
          newCol.dataType match {
            case LongType | DoubleType | FloatType | StringType | IntegerType | DateType | TimestampType | ShortType | ByteType =>
              (df.schema.fields.indexOf(newCol), newCol)
            case d: DecimalType =>
              (df.schema.fields.indexOf(newCol), newCol)
            case _ =>
              (-1, null)
          }
        }
      }.filter(_._1 != -1)
      // Complex type found, use createZIndexedDataFrameByRange
      if (zFields.length != orderByCols.length) {
        return sortDataFrameBySampleSupportAllTypes(df, orderByCols, targetPartitionsCount)
      }

      val rawRdd = df.rdd
      val sampleRdd = rawRdd.map { row =>
        val values = zFields.map { case (index, field) =>
          field.dataType match {
            case LongType =>
              if (row.isNullAt(index)) Long.MaxValue else row.getLong(index)
            case DoubleType =>
              if (row.isNullAt(index)) Long.MaxValue else java.lang.Double.doubleToLongBits(row.getDouble(index))
            case IntegerType =>
              if (row.isNullAt(index)) Long.MaxValue else row.getInt(index).toLong
            case FloatType =>
              if (row.isNullAt(index)) Long.MaxValue else java.lang.Double.doubleToLongBits(row.getFloat(index).toDouble)
            case StringType =>
              if (row.isNullAt(index)) "" else row.getString(index)
            case DateType =>
              if (row.isNullAt(index)) Long.MaxValue else row.getDate(index).getTime
            case TimestampType =>
              if (row.isNullAt(index)) Long.MaxValue else row.getTimestamp(index).getTime
            case ByteType =>
              if (row.isNullAt(index)) Long.MaxValue else row.getByte(index).toLong
            case ShortType =>
              if (row.isNullAt(index)) Long.MaxValue else row.getShort(index).toLong
            case d: DecimalType =>
              if (row.isNullAt(index)) Long.MaxValue else row.getDecimal(index).longValue()
            case _ =>
              null
          }
        }.filter(v => v != null).toArray
        (values, null)
      }
      val zOrderBounds = df.sparkSession.sessionState.conf.getConfString(
        HoodieClusteringConfig.LAYOUT_OPTIMIZE_BUILD_CURVE_SAMPLE_SIZE.key,
        HoodieClusteringConfig.LAYOUT_OPTIMIZE_BUILD_CURVE_SAMPLE_SIZE.defaultValue.toString).toInt
      val sample = new RangeSample(zOrderBounds, sampleRdd)
      val rangeBounds = sample.getRangeBounds()
      if (rangeBounds.size <= 1)
        return df
      val sampleBounds = {
        val candidateColNumber = rangeBounds.head._1.length
        (0 to candidateColNumber - 1).map { i =>
          val colRangeBound = rangeBounds.map(x => (x._1(i), x._2))

          if (colRangeBound.head._1.isInstanceOf[String]) {
            sample.determineBound(colRangeBound.asInstanceOf[ArrayBuffer[(String, Float)]], math.min(zOrderBounds, rangeBounds.length), Ordering[String])
          } else {
            sample.determineBound(colRangeBound.asInstanceOf[ArrayBuffer[(Long, Float)]], math.min(zOrderBounds, rangeBounds.length), Ordering[Long])
          }
        }
      }

      // expand bounds.
      // maybe it's better to use the value of "spark.zorder.bounds.number" as maxLength,
      // however this will lead to extra time costs when all zorder cols distinct count values are less then "spark.zorder.bounds.number"
      val maxLength = sampleBounds.map(_.length).max
      val expandSampleBoundsWithFactor = sampleBounds.map { bound =>
        val fillFactor = maxLength / bound.size
        val newBound = new Array[Double](bound.length * fillFactor)
        if (bound.isInstanceOf[Array[Long]] && fillFactor > 1) {
          val longBound = bound.asInstanceOf[Array[Long]]
          for (i <- 0 to bound.length - 1) {
            for (j <- 0 to fillFactor - 1) {
              // sample factor should not be too large, so it's ok to use 1 / fillfactor as slice
              newBound(j + i*(fillFactor)) = longBound(i) + (j + 1) * (1 / fillFactor.toDouble)
            }
          }
          (newBound, fillFactor)
        } else {
          (bound, 0)
        }
      }

      val boundBroadCast = spark.sparkContext.broadcast(expandSampleBoundsWithFactor)

      val indexRdd = rawRdd.mapPartitions { iter =>
        val expandBoundsWithFactor = boundBroadCast.value
        val maxBoundNum = expandBoundsWithFactor.map(_._1.length).max
        val longDecisionBound = new RawDecisionBound(Ordering[Long])
        val doubleDecisionBound = new RawDecisionBound(Ordering[Double])
        val stringDecisionBound = new RawDecisionBound(Ordering[String])
        import java.util.concurrent.ThreadLocalRandom
        val threadLocalRandom = ThreadLocalRandom.current

        def getRank(rawIndex: Int, value: Long, isNull: Boolean): Int = {
          val (expandBound, factor) = expandBoundsWithFactor(rawIndex)
          if (isNull) {
            expandBound.length + 1
          } else {
            if (factor > 1) {
              doubleDecisionBound.getBound(value + (threadLocalRandom.nextInt(factor) + 1)*(1 / factor.toDouble), expandBound.asInstanceOf[Array[Double]])
            } else {
              longDecisionBound.getBound(value, expandBound.asInstanceOf[Array[Long]])
            }
          }
        }

        val hilbertCurve = if (layoutOptStrategy == LayoutOptimizationStrategy.HILBERT)
          Some(HilbertCurve.bits(32).dimensions(zFields.length))
        else
          None

        iter.map { row =>
          val values = zFields.zipWithIndex.map { case ((index, field), rawIndex) =>
            field.dataType match {
              case LongType =>
                val isNull = row.isNullAt(index)
                getRank(rawIndex, if (isNull) 0 else row.getLong(index), isNull)
              case DoubleType =>
                val isNull = row.isNullAt(index)
                getRank(rawIndex, if (isNull) 0 else java.lang.Double.doubleToLongBits(row.getDouble(index)), isNull)
              case IntegerType =>
                val isNull = row.isNullAt(index)
                getRank(rawIndex, if (isNull) 0 else row.getInt(index).toLong, isNull)
              case FloatType =>
                val isNull = row.isNullAt(index)
                getRank(rawIndex, if (isNull) 0 else java.lang.Double.doubleToLongBits(row.getFloat(index).toDouble), isNull)
              case StringType =>
                val factor = maxBoundNum.toDouble / expandBoundsWithFactor(rawIndex)._1.length
                if (row.isNullAt(index)) {
                  maxBoundNum + 1
                } else {
                  val currentRank = stringDecisionBound.getBound(row.getString(index), expandBoundsWithFactor(rawIndex)._1.asInstanceOf[Array[String]])
                  if (factor > 1) {
                    (currentRank*factor).toInt + threadLocalRandom.nextInt(factor.toInt)
                  } else {
                    currentRank
                  }
                }
              case DateType =>
                val isNull = row.isNullAt(index)
                getRank(rawIndex, if (isNull) 0 else row.getDate(index).getTime, isNull)
              case TimestampType =>
                val isNull = row.isNullAt(index)
                getRank(rawIndex, if (isNull) 0 else row.getTimestamp(index).getTime, isNull)
              case ByteType =>
                val isNull = row.isNullAt(index)
                getRank(rawIndex, if (isNull) 0 else row.getByte(index).toLong, isNull)
              case ShortType =>
                val isNull = row.isNullAt(index)
                getRank(rawIndex, if (isNull) 0 else row.getShort(index).toLong, isNull)
              case d: DecimalType =>
                val isNull = row.isNullAt(index)
                getRank(rawIndex, if (isNull) 0 else row.getDecimal(index).longValue(), isNull)
              case _ =>
                -1
            }
          }.filter(v => v != -1)

          val mapValues = layoutOptStrategy match {
            case LayoutOptimizationStrategy.HILBERT =>
              HilbertCurveUtils.indexBytes(hilbertCurve.get, values.map(_.toLong).toArray, 32)
            case LayoutOptimizationStrategy.ZORDER =>
              BinaryUtil.interleaving(values.map(BinaryUtil.intTo8Byte(_)).toArray, 8)
          }

          Row.fromSeq(row.toSeq ++ Seq(mapValues))
        }
      }.sortBy(x => ByteArraySorting(x.getAs[Array[Byte]](fieldNum)), numPartitions = targetPartitionsCount)
      val newDF = df.sparkSession.createDataFrame(indexRdd, StructType(
        df.schema.fields ++ Seq(
          StructField(s"index",
            BinaryType, false))
      ))
      newDF.drop("index")
    }
  }