in paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/WriteIntoPaimonTable.scala [61:155]
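// Entry point of the write command: assign a bucket to every input record, write the records
// through Paimon's batch write API, and commit the result as one snapshot.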
override def run(sparkSession: SparkSession): Seq[Row] = {
  import sparkSession.implicits._

  if (mergeSchema) {
    val allowExplicitCast = options.get(SparkConnectorOptions.EXPLICIT_CAST)
    mergeAndCommitSchema(data.schema, allowExplicitCast)
  }
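
  // Translate the Spark SaveMode into Paimon's overwrite semantics: a dynamic partition
  // overwrite flag plus the static partition spec to overwrite (null means plain append).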
  val (dynamicPartitionOverwriteMode, overwritePartition) = parseSaveMode()
  // use the extra options to rebuild the table object
  updateTableWithOptions(
    Map(DYNAMIC_PARTITION_OVERWRITE.key -> dynamicPartitionOverwriteMode.toString))

  val primaryKeyCols = tableSchema.trimmedPrimaryKeys().asScala.map(col)
  val partitionCols = tableSchema.partitionKeys().asScala.map(col)
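
  // Encoders used to convert between external Rows and Catalyst InternalRows when computing
  // and attaching the bucket id.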
  val dataEncoder = RowEncoder.apply(data.schema).resolveAndBind()
  val originFromRow = dataEncoder.createDeserializer()

  // append _bucket_ column as placeholder
  val withBucketCol = data.withColumn(BUCKET_COL, lit(-1))
  val bucketColIdx = withBucketCol.schema.size - 1
  val withBucketDataEncoder = RowEncoder.apply(withBucketCol.schema).resolveAndBind()
  val toRow = withBucketDataEncoder.createSerializer()
  val fromRow = withBucketDataEncoder.createDeserializer()
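
  // Co-locate records that belong to the same partition and bucket so that each Spark task
  // writes to a bounded set of Paimon buckets.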
  def repartitionByBucket(ds: Dataset[Row]) = {
    ds.toDF().repartition(partitionCols ++ Seq(col(BUCKET_COL)): _*)
  }

  val rowType = table.rowType()
  val writeBuilder = table.newBatchWriteBuilder()
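
  // Fill the placeholder _bucket_ column; how the bucket id is computed depends on the
  // table's bucket mode.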
  val df =
    bucketMode match {
      case BucketMode.DYNAMIC =>
        // Dynamic bucket mode: buckets are assigned per record based on its primary key.
        val partitioned = if (primaryKeyCols.nonEmpty) {
          // Make sure that records with the same primary keys (and hence the same bucket)
          // end up in the same task.
          withBucketCol.repartition(primaryKeyCols: _*)
        } else {
          withBucketCol
        }
        val numSparkPartitions = partitioned.rdd.getNumPartitions
        val dynamicBucketProcessor =
          DynamicBucketProcessor(table, rowType, bucketColIdx, numSparkPartitions, toRow, fromRow)
        repartitionByBucket(
          partitioned.mapPartitions(dynamicBucketProcessor.processPartition)(
            withBucketDataEncoder))
      case BucketMode.UNAWARE =>
        // Unaware bucket mode: records are not shuffled by bucket.
        val unawareBucketProcessor = UnawareBucketProcessor(bucketColIdx, toRow, fromRow)
        withBucketCol
          .mapPartitions(unawareBucketProcessor.processPartition)(withBucketDataEncoder)
          .toDF()
      case BucketMode.FIXED =>
        // Fixed bucket mode: compute each record's bucket, then group by partition and bucket.
        val commonBucketProcessor =
          CommonBucketProcessor(writeBuilder, bucketColIdx, toRow, fromRow)
        repartitionByBucket(
          withBucketCol.mapPartitions(commonBucketProcessor.processPartition)(
            withBucketDataEncoder))
    }
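
  // Write each Spark partition with its own BatchTableWrite and send the serialized
  // CommitMessages back to the driver.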
  val commitMessages = df
    .mapPartitions {
      iter =>
        val write = writeBuilder.newWrite()
        write.withIOManager(createIOManager)
        try {
          iter.foreach {
            row =>
              val bucket = row.getInt(bucketColIdx)
              // Re-encode with the original (bucket-less) schema to drop the _bucket_ column.
              val bucketColDropped = originFromRow(toRow(row))
              write.write(new DynamicBucketRow(new SparkRow(rowType, bucketColDropped), bucket))
          }
          // Executor-side serializer: commit messages travel to the driver as bytes.
          val serializer = new CommitMessageSerializer
          write.prepareCommit().asScala.map(serializer.serialize).toIterator
        } finally {
          write.close()
        }
    }
    .collect()
    // The driver deserializes with the CommitMessageSerializer declared outside this excerpt.
    .map(deserializeCommitMessage(serializer, _))
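
  // Commit all messages as a single snapshot; in overwrite mode the commit replaces the data
  // of the specified partitions instead of appending.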
  try {
    val tableCommit = if (overwritePartition == null) {
      writeBuilder.newCommit()
    } else {
      writeBuilder.withOverwrite(overwritePartition.asJava).newCommit()
    }
    tableCommit.commit(commitMessages.toList.asJava)
  } catch {
    case e: Throwable => throw new RuntimeException(e)
  }

  Seq.empty
}