override def run()

in paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/WriteIntoPaimonTable.scala [61:155]
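
This is the batch write entry point of the WriteIntoPaimonTable command: it optionally merges the incoming DataFrame schema into the table schema, resolves the save mode into overwrite settings, assigns a bucket to every row according to the table's bucket mode (dynamic, unaware, or fixed), writes each Spark partition through a BatchTableWrite, and finally commits the collected commit messages on the driver.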


  override def run(sparkSession: SparkSession): Seq[Row] = {
    import sparkSession.implicits._

    if (mergeSchema) {
      val allowExplicitCast = options.get(SparkConnectorOptions.EXPLICIT_CAST)
      mergeAndCommitSchema(data.schema, allowExplicitCast)
    }

    val (dynamicPartitionOverwriteMode, overwritePartition) = parseSaveMode()
    // use the extra options to rebuild the table object
    updateTableWithOptions(
      Map(DYNAMIC_PARTITION_OVERWRITE.key -> dynamicPartitionOverwriteMode.toString))

    val primaryKeyCols = tableSchema.trimmedPrimaryKeys().asScala.map(col)
    val partitionCols = tableSchema.partitionKeys().asScala.map(col)

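    // Deserializer bound to the original (bucket-less) schema; it is used at write time
    // to drop the trailing _bucket_ placeholder column again (see bucketColDropped below).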
    val dataEncoder = RowEncoder.apply(data.schema).resolveAndBind()
    val originFromRow = dataEncoder.createDeserializer()

    // append _bucket_ column as placeholder
    val withBucketCol = data.withColumn(BUCKET_COL, lit(-1))
    val bucketColIdx = withBucketCol.schema.size - 1
    val withBucketDataEncoder = RowEncoder.apply(withBucketCol.schema).resolveAndBind()
    val toRow = withBucketDataEncoder.createSerializer()
    val fromRow = withBucketDataEncoder.createDeserializer()

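    // Co-locate records of the same partition and bucket in one Spark task before writing.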
    def repartitionByBucket(ds: Dataset[Row]) = {
      ds.toDF().repartition(partitionCols ++ Seq(col(BUCKET_COL)): _*)
    }

    val rowType = table.rowType()
    val writeBuilder = table.newBatchWriteBuilder()

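    // Compute a real bucket for every row, according to the table's bucket mode.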
    val df =
      bucketMode match {
        case BucketMode.DYNAMIC =>
          val partitioned = if (primaryKeyCols.nonEmpty) {
            // Make sure that records with the same bucket values are processed within the same task.
            withBucketCol.repartition(primaryKeyCols: _*)
          } else {
            withBucketCol
          }
          val numSparkPartitions = partitioned.rdd.getNumPartitions
          val dynamicBucketProcessor =
            DynamicBucketProcessor(table, rowType, bucketColIdx, numSparkPartitions, toRow, fromRow)
          repartitionByBucket(
            partitioned.mapPartitions(dynamicBucketProcessor.processPartition)(
              withBucketDataEncoder))
        case BucketMode.UNAWARE =>
          val unawareBucketProcessor = UnawareBucketProcessor(bucketColIdx, toRow, fromRow)
          withBucketCol
            .mapPartitions(unawareBucketProcessor.processPartition)(withBucketDataEncoder)
            .toDF()
        case BucketMode.FIXED =>
          val commonBucketProcessor =
            CommonBucketProcessor(writeBuilder, bucketColIdx, toRow, fromRow)
          repartitionByBucket(
            withBucketCol.mapPartitions(commonBucketProcessor.processPartition)(
              withBucketDataEncoder))
      }

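    // Each Spark partition opens its own BatchTableWrite and returns its commit messages
    // as serialized bytes; the driver-side serializer below decodes them after collect().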
    val serializer = new CommitMessageSerializer
    val commitMessages = df
      .mapPartitions {
        iter =>
          val write = writeBuilder.newWrite()
          write.withIOManager(createIOManager)
          try {
            iter.foreach {
              row =>
                val bucket = row.getInt(bucketColIdx)
                val bucketColDropped = originFromRow(toRow(row))
                write.write(new DynamicBucketRow(new SparkRow(rowType, bucketColDropped), bucket))
            }
            val serializer = new CommitMessageSerializer
            write.prepareCommit().asScala.map(serializer.serialize).toIterator
          } finally {
            write.close()
          }
      }
      .collect()
      .map(deserializeCommitMessage(serializer, _))

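    // Single batch commit on the driver, with partition overwrite when requested.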
    try {
      val tableCommit = if (overwritePartition == null) {
        writeBuilder.newCommit()
      } else {
        writeBuilder.withOverwrite(overwritePartition.asJava).newCommit()
      }
      tableCommit.commit(commitMessages.toList.asJava)
    } catch {
      case e: Throwable => throw new RuntimeException(e)
    }

    Seq.empty
  }