in paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/PaimonSparkWriter.scala [80:275]
def write(data: DataFrame): Seq[CommitMessage] = {
  val sparkSession = data.sparkSession
  import sparkSession.implicits._

  // For bucketed write modes, attach a placeholder bucket column (-1), plus a row-kind
  // column in cross-partition mode, so the processors below can fill in the real bucket.
  val withInitBucketCol = bucketMode match {
    case BUCKET_UNAWARE => data
    case CROSS_PARTITION if !data.schema.fieldNames.contains(ROW_KIND_COL) =>
      data
        .withColumn(ROW_KIND_COL, lit(RowKind.INSERT.toByteValue))
        .withColumn(BUCKET_COL, lit(-1))
    case _ => data.withColumn(BUCKET_COL, lit(-1))
  }
  val rowKindColIdx = SparkRowUtils.getFieldIndex(withInitBucketCol.schema, ROW_KIND_COL)
  val bucketColIdx = SparkRowUtils.getFieldIndex(withInitBucketCol.schema, BUCKET_COL)
  val encoderGroupWithBucketCol = EncoderSerDeGroup(withInitBucketCol.schema)

  def newWrite(): SparkTableWrite = new SparkTableWrite(writeBuilder, rowType, rowKindColIdx)

  // Use the larger of the default parallelism and the shuffle-partition setting.
  def sparkParallelism = {
    val defaultParallelism = sparkSession.sparkContext.defaultParallelism
    val numShufflePartitions = sparkSession.sessionState.conf.numShufflePartitions
    Math.max(defaultParallelism, numShufflePartitions)
  }
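
  // Each write helper below creates a SparkTableWrite per Spark task and returns the
  // commit messages produced by that task, serialized as Array[Byte]; they are
  // deserialized on the driver after the job completes.

  // Writes rows without bucket information (used for bucket-unaware and postpone-bucket tables).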
  def writeWithoutBucket(dataFrame: DataFrame): Dataset[Array[Byte]] = {
    dataFrame.mapPartitions {
      iter =>
        {
          val write = newWrite()
          try {
            iter.foreach(row => write.write(row))
            write.finish()
          } finally {
            write.close()
          }
        }
    }
  }
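
  // Same as writeWithoutBucket, but routes each row to the bucket value already stored in BUCKET_COL.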
  def writeWithBucket(dataFrame: DataFrame): Dataset[Array[Byte]] = {
    dataFrame.mapPartitions {
      iter =>
        {
          val write = newWrite()
          try {
            iter.foreach(row => write.write(row, row.getInt(bucketColIdx)))
            write.finish()
          } finally {
            write.close()
          }
        }
    }
  }
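
  // Fills BUCKET_COL via the given BucketProcessor, shuffles by partition & bucket, then writes.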
  def writeWithBucketProcessor(
      dataFrame: DataFrame,
      processor: BucketProcessor[Row]): Dataset[Array[Byte]] = {
    val repartitioned = repartitionByPartitionsAndBucket(
      dataFrame
        .mapPartitions(processor.processPartition)(encoderGroupWithBucketCol.encoder)
        .toDF())
    writeWithBucket(repartitioned)
  }
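
  // Assigns a bucket per row with a task-local assigner function; no further shuffle happens after assignment.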
  def writeWithBucketAssigner(
      dataFrame: DataFrame,
      funcFactory: () => Row => Int): Dataset[Array[Byte]] = {
    dataFrame.mapPartitions {
      iter =>
        {
          val assigner = funcFactory.apply()
          val write = newWrite()
          try {
            iter.foreach(row => write.write(row, assigner.apply(row)))
            write.finish()
          } finally {
            write.close()
          }
        }
    }
  }
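
  // Dispatch on the table's bucket mode; every branch produces serialized commit messages.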
  val written: Dataset[Array[Byte]] = bucketMode match {
    case CROSS_PARTITION =>
      // Topology: input -> bootstrap -> shuffle by key hash -> bucket-assigner -> shuffle by partition & bucket
      val rowType = SparkTypeUtils.toPaimonType(withInitBucketCol.schema).asInstanceOf[RowType]
      val assignerParallelism = Option(table.coreOptions.dynamicBucketAssignerParallelism)
        .map(_.toInt)
        .getOrElse(sparkParallelism)
      val bootstrapped = bootstrapAndRepartitionByKeyHash(
        withInitBucketCol,
        assignerParallelism,
        rowKindColIdx,
        rowType)

      val globalDynamicBucketProcessor =
        GlobalDynamicBucketProcessor(
          table,
          rowType,
          assignerParallelism,
          encoderGroupWithBucketCol)
      val repartitioned = repartitionByPartitionsAndBucket(
        sparkSession.createDataFrame(
          bootstrapped.mapPartitions(globalDynamicBucketProcessor.processPartition),
          withInitBucketCol.schema))

      writeWithBucket(repartitioned)

    case HASH_DYNAMIC =>
      val assignerParallelism = Option(table.coreOptions.dynamicBucketAssignerParallelism)
        .map(_.toInt)
        .getOrElse(sparkParallelism)
      // Cap the number of assigners by the configured initial buckets, if set.
      val numAssigners = Option(table.coreOptions.dynamicBucketInitialBuckets)
        .map(initialBuckets => Math.min(initialBuckets.toInt, assignerParallelism))
        .getOrElse(assignerParallelism)

      def partitionByKey(): DataFrame = {
        repartitionByKeyPartitionHash(
          sparkSession,
          withInitBucketCol,
          assignerParallelism,
          numAssigners)
      }

      if (table.snapshotManager().latestSnapshot() == null) {
        // bootstrap mode
        // Topology: input -> shuffle by special key & partition hash -> bucket-assigner
        writeWithBucketAssigner(
          partitionByKey(),
          () => {
            val extractor = new RowPartitionKeyExtractor(table.schema)
            val assigner =
              new SimpleHashBucketAssigner(
                numAssigners,
                TaskContext.getPartitionId(),
                table.coreOptions.dynamicBucketTargetRowNum,
                table.coreOptions.dynamicBucketMaxBuckets
              )
            row => {
              val sparkRow = new SparkRow(rowType, row)
              assigner.assign(
                extractor.partition(sparkRow),
                extractor.trimmedPrimaryKey(sparkRow).hashCode)
            }
          }
        )
      } else {
        // Topology: input -> shuffle by special key & partition hash -> bucket-assigner -> shuffle by partition & bucket
        writeWithBucketProcessor(
          partitionByKey(),
          DynamicBucketProcessor(
            table,
            bucketColIdx,
            assignerParallelism,
            numAssigners,
            encoderGroupWithBucketCol)
        )
      }

    case BUCKET_UNAWARE =>
      // Topology: input -> write (no shuffle needed)
      writeWithoutBucket(data)

    case HASH_FIXED =>
      if (table.bucketSpec().getNumBuckets == POSTPONE_BUCKET) {
        // Bucket assignment is postponed for this table, so write without a bucket here.
        writeWithoutBucket(data)
      } else if (paimonExtensionEnabled && BucketFunction.supportsTable(table)) {
        // Topology: input -> shuffle by partition & bucket
        val bucketNumber = table.coreOptions().bucket()
        val bucketKeyCol = tableSchema
          .bucketKeys()
          .asScala
          .map(tableSchema.fieldNames().indexOf(_))
          .map(x => col(data.schema.fieldNames(x)))
          .toSeq
        val args = Seq(lit(bucketNumber)) ++ bucketKeyCol
        val repartitioned =
          repartitionByPartitionsAndBucket(
            data.withColumn(BUCKET_COL, call_udf(BucketExpression.FIXED_BUCKET, args: _*)))
        writeWithBucket(repartitioned)
      } else {
        // Topology: input -> bucket-assigner -> shuffle by partition & bucket
        writeWithBucketProcessor(
          withInitBucketCol,
          CommonBucketProcessor(table, bucketColIdx, encoderGroupWithBucketCol))
      }

    case _ =>
      throw new UnsupportedOperationException(s"Spark doesn't support $bucketMode mode.")
  }
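
  // Collect the serialized commit messages on the driver and deserialize them.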
  written
    .collect()
    .map(deserializeCommitMessage(serializer, _))
    .toSeq
}
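
// Hedged usage sketch (not part of the original file): assuming a PaimonSparkWriter built
// for the target FileStoreTable and a commit(...) method on the same class, callers chain
// the two steps, e.g.:
//
//   val writer = PaimonSparkWriter(table)   // assumed case-class constructor
//   val messages = writer.write(df)         // Seq[CommitMessage] from the method above
//   writer.commit(messages)                 // assumed to commit them as a single snapshot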