def write()

in paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/PaimonSparkWriter.scala [80:275]


  def write(data: DataFrame): Seq[CommitMessage] = {
    val sparkSession = data.sparkSession
    import sparkSession.implicits._

    val withInitBucketCol = bucketMode match {
      case BUCKET_UNAWARE => data
      case CROSS_PARTITION if !data.schema.fieldNames.contains(ROW_KIND_COL) =>
        data
          .withColumn(ROW_KIND_COL, lit(RowKind.INSERT.toByteValue))
          .withColumn(BUCKET_COL, lit(-1))
      case _ => data.withColumn(BUCKET_COL, lit(-1))
    }
    val rowKindColIdx = SparkRowUtils.getFieldIndex(withInitBucketCol.schema, ROW_KIND_COL)
    val bucketColIdx = SparkRowUtils.getFieldIndex(withInitBucketCol.schema, BUCKET_COL)
    val encoderGroupWithBucketCol = EncoderSerDeGroup(withInitBucketCol.schema)

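    // Creates a task-local SparkTableWrite; rowKindColIdx points at the ROW_KIND_COL field (-1 when the column is absent).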
    def newWrite(): SparkTableWrite = new SparkTableWrite(writeBuilder, rowType, rowKindColIdx)

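    // Fallback parallelism for the bucket assigner when the table does not configure one.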
    def sparkParallelism = {
      val defaultParallelism = sparkSession.sparkContext.defaultParallelism
      val numShufflePartitions = sparkSession.sessionState.conf.numShufflePartitions
      Math.max(defaultParallelism, numShufflePartitions)
    }

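    // Writes rows without any bucket routing; each task emits the serialized commit messages returned by write.finish().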
    def writeWithoutBucket(dataFrame: DataFrame): Dataset[Array[Byte]] = {
      dataFrame.mapPartitions {
        iter =>
          {
            val write = newWrite()
            try {
              iter.foreach(row => write.write(row))
              write.finish()
            } finally {
              write.close()
            }
          }
      }
    }

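    // Writes rows whose target bucket has already been computed into BUCKET_COL.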
    def writeWithBucket(dataFrame: DataFrame): Dataset[Array[Byte]] = {
      dataFrame.mapPartitions {
        iter =>
          {
            val write = newWrite()
            try {
              iter.foreach(row => write.write(row, row.getInt(bucketColIdx)))
              write.finish()
            } finally {
              write.close()
            }
          }
      }
    }

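    // Runs the bucket processor to fill BUCKET_COL, shuffles by partition & bucket, then writes.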
    def writeWithBucketProcessor(
        dataFrame: DataFrame,
        processor: BucketProcessor[Row]): Dataset[Array[Byte]] = {
      val repartitioned = repartitionByPartitionsAndBucket(
        dataFrame
          .mapPartitions(processor.processPartition)(encoderGroupWithBucketCol.encoder)
          .toDF())
      writeWithBucket(repartitioned)
    }

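    // Computes each row's bucket with a task-local assigner function and writes directly, without a second shuffle.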
    def writeWithBucketAssigner(
        dataFrame: DataFrame,
        funcFactory: () => Row => Int): Dataset[Array[Byte]] = {
      dataFrame.mapPartitions {
        iter =>
          {
            val assigner = funcFactory.apply()
            val write = newWrite()
            try {
              iter.foreach(row => write.write(row, assigner.apply(row)))
              write.finish()
            } finally {
              write.close()
            }
          }
      }
    }

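    // Build the write topology according to the table's bucket mode.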
    val written: Dataset[Array[Byte]] = bucketMode match {
      case CROSS_PARTITION =>
        // Topology: input -> bootstrap -> shuffle by key hash -> bucket-assigner -> shuffle by partition & bucket
        val rowType = SparkTypeUtils.toPaimonType(withInitBucketCol.schema).asInstanceOf[RowType]
        val assignerParallelism = Option(table.coreOptions.dynamicBucketAssignerParallelism)
          .map(_.toInt)
          .getOrElse(sparkParallelism)
        val bootstrapped = bootstrapAndRepartitionByKeyHash(
          withInitBucketCol,
          assignerParallelism,
          rowKindColIdx,
          rowType)

        val globalDynamicBucketProcessor =
          GlobalDynamicBucketProcessor(
            table,
            rowType,
            assignerParallelism,
            encoderGroupWithBucketCol)
        val repartitioned = repartitionByPartitionsAndBucket(
          sparkSession.createDataFrame(
            bootstrapped.mapPartitions(globalDynamicBucketProcessor.processPartition),
            withInitBucketCol.schema))

        writeWithBucket(repartitioned)

      case HASH_DYNAMIC =>
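        // Dynamic bucket mode: rows are routed by a hash bucket assigner whose parallelism and assigner count come from the table's dynamic-bucket options.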
        val assignerParallelism = Option(table.coreOptions.dynamicBucketAssignerParallelism)
          .map(_.toInt)
          .getOrElse(sparkParallelism)
        val numAssigners = Option(table.coreOptions.dynamicBucketInitialBuckets)
          .map(initialBuckets => Math.min(initialBuckets.toInt, assignerParallelism))
          .getOrElse(assignerParallelism)

        def partitionByKey(): DataFrame = {
          repartitionByKeyPartitionHash(
            sparkSession,
            withInitBucketCol,
            assignerParallelism,
            numAssigners)
        }

        if (table.snapshotManager().latestSnapshot() == null) {
          // bootstrap mode
          // Topology: input -> shuffle by special key & partition hash -> bucket-assigner
          writeWithBucketAssigner(
            partitionByKey(),
            () => {
              val extractor = new RowPartitionKeyExtractor(table.schema)
              val assigner =
                new SimpleHashBucketAssigner(
                  numAssigners,
                  TaskContext.getPartitionId(),
                  table.coreOptions.dynamicBucketTargetRowNum,
                  table.coreOptions.dynamicBucketMaxBuckets
                )
              row => {
                val sparkRow = new SparkRow(rowType, row)
                assigner.assign(
                  extractor.partition(sparkRow),
                  extractor.trimmedPrimaryKey(sparkRow).hashCode)
              }
            }
          )
        } else {
          // Topology: input -> shuffle by special key & partition hash -> bucket-assigner -> shuffle by partition & bucket
          writeWithBucketProcessor(
            partitionByKey(),
            DynamicBucketProcessor(
              table,
              bucketColIdx,
              assignerParallelism,
              numAssigners,
              encoderGroupWithBucketCol)
          )
        }

      case BUCKET_UNAWARE =>
        // Topology: input -> write (unaware-bucket tables need no shuffle)
        writeWithoutBucket(data)

      case HASH_FIXED =>
        if (table.bucketSpec().getNumBuckets == POSTPONE_BUCKET) {
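          // Postpone-bucket tables are written without bucket routing here.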
          writeWithoutBucket(data)
        } else if (paimonExtensionEnabled && BucketFunction.supportsTable(table)) {
          // Topology: input -> shuffle by partition & bucket
          val bucketNumber = table.coreOptions().bucket()
          val bucketKeyCol = tableSchema
            .bucketKeys()
            .asScala
            .map(tableSchema.fieldNames().indexOf(_))
            .map(x => col(data.schema.fieldNames(x)))
            .toSeq
          val args = Seq(lit(bucketNumber)) ++ bucketKeyCol
          val repartitioned =
            repartitionByPartitionsAndBucket(
              data.withColumn(BUCKET_COL, call_udf(BucketExpression.FIXED_BUCKET, args: _*)))
          writeWithBucket(repartitioned)
        } else {
          // Topology: input -> bucket-assigner -> shuffle by partition & bucket
          writeWithBucketProcessor(
            withInitBucketCol,
            CommonBucketProcessor(table, bucketColIdx, encoderGroupWithBucketCol))
        }

      case _ =>
        throw new UnsupportedOperationException(s"Spark doesn't support $bucketMode mode.")
    }

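    // Collect the serialized commit messages to the driver and deserialize them.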
    written
      .collect()
      .map(deserializeCommitMessage(serializer, _))
      .toSeq
  }
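
A minimal calling sketch. It assumes the enclosing PaimonSparkWriter is constructed from a FileStoreTable and exposes a commit step for the returned messages (as the surrounding command classes do); treat the constructor and the commit signature below as assumptions rather than the exact API:

  // Sketch only: construction and commit() are assumed, not taken from this excerpt.
  val writer = PaimonSparkWriter(table)   // table: FileStoreTable (assumption)
  val commitMessages = writer.write(df)   // df: the DataFrame to write
  if (commitMessages.nonEmpty) {
    writer.commit(commitMessages)         // assumed commit API on the writer
  }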