in hudi-client/hudi-spark-client/src/main/scala/org/apache/spark/sql/hudi/execution/RangeSample.scala [256:449]
/**
 * Re-orders the rows of `df` along a space-filling curve (Z-order or Hilbert) computed over
 * the `orderByCols` columns, using range bounds estimated from a sample of the data.
 *
 * Outline:
 *  1. Resolve `orderByCols` against the schema. If none of them exists, `df` is returned
 *     unchanged. If any column is missing or has a type outside the directly supported set
 *     (numeric, string, date, timestamp, decimal), the call is delegated to
 *     `sortDataFrameBySampleSupportAllTypes`.
 *  2. Sample every sort column (strings as-is, everything else as an order-preserving Long)
 *     and derive per-column range bounds through `RangeSample`. The number of bounds is read
 *     from `HoodieClusteringConfig.LAYOUT_OPTIMIZE_BUILD_CURVE_SAMPLE_SIZE` in the session conf.
 *  3. Replace each cell by its rank within the column's bounds, combine the ranks of a row
 *     into one byte array (Hilbert index or bit interleaving) and sort the rows by it.
 *
 * @param df                    the DataFrame to sort
 * @param layoutOptStrategy     curve used to combine per-column ranks; only `HILBERT` and
 *                              `ZORDER` are handled by this method
 * @param orderByCols           names of the columns that drive the ordering
 * @param targetPartitionsCount number of partitions of the sorted output
 * @return a DataFrame with the schema of `df` whose rows are sorted by the curve value, or
 *         `df` itself when there is nothing to sort by or fewer than two range bounds
 *         could be sampled
 */
def sortDataFrameBySample(df: DataFrame, layoutOptStrategy: LayoutOptimizationStrategy, orderByCols: Seq[String], targetPartitionsCount: Int): DataFrame = {
  val spark = df.sparkSession
  val schemaFields = df.schema.fields
  val columnsMap = schemaFields.map(item => (item.name, item)).toMap
  val fieldNum = schemaFields.length
  // Map#apply throws on a missing key (it never yields null), so membership has to be
  // tested with contains/get for unknown column names to be tolerated.
  val checkCols = orderByCols.filter(columnsMap.contains)

  // Raw IEEE-754 bits order negative values in reverse; flipping the 63 low bits of
  // negative inputs yields a Long whose natural ordering matches the numeric ordering.
  def sortableDoubleBits(value: Double): Long = {
    val bits = java.lang.Double.doubleToLongBits(value)
    if (bits < 0) bits ^ Long.MaxValue else bits
  }

  // Converts a non-null cell of a supported non-string column into the Long used both for
  // sampling range bounds and for ranking rows, so the two phases always agree.
  def orderingKey(row: Row, index: Int, field: StructField): Long = field.dataType match {
    case LongType => row.getLong(index)
    case DoubleType => sortableDoubleBits(row.getDouble(index))
    case IntegerType => row.getInt(index).toLong
    case FloatType => sortableDoubleBits(row.getFloat(index).toDouble)
    case DateType => row.getDate(index).getTime
    case TimestampType => row.getTimestamp(index).getTime
    case ByteType => row.getByte(index).toLong
    case ShortType => row.getShort(index).toLong
    // Only the integral part of a decimal takes part in the ordering.
    case _: DecimalType => row.getDecimal(index).longValue()
    case other => throw new IllegalArgumentException(s"Unsupported sort column type: $other")
  }

  if (orderByCols.isEmpty || checkCols.isEmpty) {
    df
  } else {
    // (position in the schema, field) for every sort column that exists and has a type
    // this method can rank directly.
    val zFields = orderByCols.flatMap { col =>
      columnsMap.get(col).flatMap { field =>
        field.dataType match {
          case LongType | DoubleType | FloatType | StringType | IntegerType | DateType | TimestampType | ShortType | ByteType =>
            Some((schemaFields.indexOf(field), field))
          case _: DecimalType =>
            Some((schemaFields.indexOf(field), field))
          case _ =>
            None
        }
      }
    }
    // At least one sort column is missing or has an unsupported (e.g. complex) type:
    // hand over to sortDataFrameBySampleSupportAllTypes.
    if (zFields.length != orderByCols.length) {
      return sortDataFrameBySampleSupportAllTypes(df, orderByCols, targetPartitionsCount)
    }

    val rawRdd = df.rdd
    // One array of comparable values per row; nulls are sampled as "" for strings and as
    // Long.MaxValue for everything else.
    val sampleRdd = rawRdd.map { row =>
      val values = zFields.map { case (index, field) =>
        field.dataType match {
          case StringType =>
            if (row.isNullAt(index)) "" else row.getString(index)
          case _ =>
            if (row.isNullAt(index)) Long.MaxValue else orderingKey(row, index, field)
        }
      }.toArray
      (values, null)
    }

    val zOrderBounds = spark.sessionState.conf.getConfString(
      HoodieClusteringConfig.LAYOUT_OPTIMIZE_BUILD_CURVE_SAMPLE_SIZE.key,
      HoodieClusteringConfig.LAYOUT_OPTIMIZE_BUILD_CURVE_SAMPLE_SIZE.defaultValue.toString).toInt
    val sample = new RangeSample(zOrderBounds, sampleRdd)
    val rangeBounds = sample.getRangeBounds()
    // Too few sampled candidates to build any meaningful bounds.
    if (rangeBounds.size <= 1)
      return df

    // Per-column bounds: project the sampled candidates onto each column and let RangeSample
    // pick the bounds with the ordering that matches the column's sampled representation.
    val sampleBounds = {
      val candidateColNumber = rangeBounds.head._1.length
      (0 until candidateColNumber).map { i =>
        val colRangeBound = rangeBounds.map(x => (x._1(i), x._2))
        if (colRangeBound.head._1.isInstanceOf[String]) {
          sample.determineBound(colRangeBound.asInstanceOf[ArrayBuffer[(String, Float)]], math.min(zOrderBounds, rangeBounds.length), Ordering[String])
        } else {
          sample.determineBound(colRangeBound.asInstanceOf[ArrayBuffer[(Long, Float)]], math.min(zOrderBounds, rangeBounds.length), Ordering[Long])
        }
      }
    }

    // expand bounds.
    // maybe it's better to use the value of "spark.zorder.bounds.number" as maxLength,
    // however this will lead to extra time costs when all zorder cols distinct count values are less than "spark.zorder.bounds.number"
    val maxLength = sampleBounds.map(_.length).max
    val expandSampleBoundsWithFactor = sampleBounds.map { bound =>
      val fillFactor = maxLength / bound.length
      if (bound.isInstanceOf[Array[Long]] && fillFactor > 1) {
        // Columns with few distinct bounds get every bound split into `fillFactor`
        // fractional sub-bounds so that their ranks span a range comparable to the
        // column with the most bounds.
        val longBound = bound.asInstanceOf[Array[Long]]
        val newBound = new Array[Double](bound.length * fillFactor)
        for (i <- 0 until bound.length) {
          for (j <- 0 until fillFactor) {
            // sample factor should not be too large, so it's ok to use 1 / fillfactor as slice
            newBound(j + i * fillFactor) = longBound(i) + (j + 1) * (1 / fillFactor.toDouble)
          }
        }
        (newBound, fillFactor)
      } else {
        // String bounds and already full-size Long bounds are used unchanged (factor 0).
        (bound, 0)
      }
    }

    val boundBroadCast = spark.sparkContext.broadcast(expandSampleBoundsWithFactor)
    val indexRdd = rawRdd.mapPartitions { iter =>
      val expandBoundsWithFactor = boundBroadCast.value
      val maxBoundNum = expandBoundsWithFactor.map(_._1.length).max
      val longDecisionBound = new RawDecisionBound(Ordering[Long])
      val doubleDecisionBound = new RawDecisionBound(Ordering[Double])
      val stringDecisionBound = new RawDecisionBound(Ordering[String])
      import java.util.concurrent.ThreadLocalRandom
      val threadLocalRandom = ThreadLocalRandom.current

      // Rank of a Long-encoded value within the bounds of sort column `rawIndex`.
      def getRank(rawIndex: Int, value: Long, isNull: Boolean): Int = {
        val (expandBound, factor) = expandBoundsWithFactor(rawIndex)
        if (isNull) {
          // Nulls rank after every bound.
          expandBound.length + 1
        } else {
          if (factor > 1) {
            // Expanded bounds: add a random fractional offset so equal values are spread
            // over the sub-bounds created for them.
            doubleDecisionBound.getBound(value + (threadLocalRandom.nextInt(factor) + 1) * (1 / factor.toDouble), expandBound.asInstanceOf[Array[Double]])
          } else {
            longDecisionBound.getBound(value, expandBound.asInstanceOf[Array[Long]])
          }
        }
      }

      val hilbertCurve = if (layoutOptStrategy == LayoutOptimizationStrategy.HILBERT)
        Some(HilbertCurve.bits(32).dimensions(zFields.length))
      else
        None

      iter.map { row =>
        val values = zFields.zipWithIndex.map { case ((index, field), rawIndex) =>
          field.dataType match {
            case StringType =>
              // String bounds are never expanded, so the rank itself is scaled up instead.
              val factor = maxBoundNum.toDouble / expandBoundsWithFactor(rawIndex)._1.length
              if (row.isNullAt(index)) {
                maxBoundNum + 1
              } else {
                val currentRank = stringDecisionBound.getBound(row.getString(index), expandBoundsWithFactor(rawIndex)._1.asInstanceOf[Array[String]])
                if (factor > 1) {
                  (currentRank * factor).toInt + threadLocalRandom.nextInt(factor.toInt)
                } else {
                  currentRank
                }
              }
            case _ =>
              val isNull = row.isNullAt(index)
              getRank(rawIndex, if (isNull) 0L else orderingKey(row, index, field), isNull)
          }
        }.filter(v => v != -1) // kept from the original: a rank of -1 never becomes a curve dimension
        val mapValues = layoutOptStrategy match {
          case LayoutOptimizationStrategy.HILBERT =>
            HilbertCurveUtils.indexBytes(hilbertCurve.get, values.map(_.toLong).toArray, 32)
          case LayoutOptimizationStrategy.ZORDER =>
            BinaryUtil.interleaving(values.map(BinaryUtil.intTo8Byte(_)).toArray, 8)
        }
        // The curve value is appended as the last field, i.e. at position `fieldNum`.
        Row.fromSeq(row.toSeq ++ Seq(mapValues))
      }
    }.sortBy(x => ByteArraySorting(x.getAs[Array[Byte]](fieldNum)), numPartitions = targetPartitionsCount)

    // Name of the temporary sort-key column. It must not clash with an existing column
    // (compared case-insensitively), because dropping by name would remove that column too.
    val takenNames = schemaFields.map(_.name.toLowerCase(java.util.Locale.ROOT)).toSet
    val sortKeyColName = Iterator.from(0)
      .map(i => if (i == 0) "index" else s"index_$i")
      .filterNot(takenNames.contains)
      .next()
    val newDF = spark.createDataFrame(indexRdd, StructType(
      schemaFields ++ Seq(StructField(sortKeyColName, BinaryType, false))))
    newDF.drop(sortKeyColName)
  }
}