def write()

in backends-clickhouse/src-delta-32/main/scala/org/apache/spark/sql/execution/datasources/v1/clickhouse/MergeTreeFileFormatWriter.scala [53:299]


  def write(
      sparkSession: SparkSession,
      plan: SparkPlan,
      fileFormat: FileFormat,
      committer: FileCommitProtocol,
      outputSpec: OutputSpec,
      hadoopConf: Configuration,
      partitionColumns: Seq[Attribute],
      bucketSpec: Option[BucketSpec],
      statsTrackers: Seq[WriteJobStatsTracker],
      options: Map[String, String],
      constraints: Seq[Constraint],
      numStaticPartitionCols: Int = 0): Set[String] = {

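    // Local properties set on the SparkContext control the write path: whether the
    // native (Gluten) writer is applicable to this plan, and whether it only supports
    // static partition writes.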
    val nativeEnabled =
      "true" == sparkSession.sparkContext.getLocalProperty("isNativeApplicable")
    val staticPartitionWriteOnly =
      "true" == sparkSession.sparkContext.getLocalProperty("staticPartitionWriteOnly")

    if (nativeEnabled) {
      logInfo("Use Gluten partition write for hive")
      assert(plan.isInstanceOf[IFakeRowAdaptor])
    }

    val job = Job.getInstance(hadoopConf)
    job.setOutputKeyClass(classOf[Void])
    job.setOutputValueClass(classOf[InternalRow])
    FileOutputFormat.setOutputPath(job, new Path(outputSpec.outputPath))

    val partitionSet = AttributeSet(partitionColumns)
    // Clean up the internal metadata information of the file source metadata
    // attribute, if any, before writing out.
    val finalOutputSpec = outputSpec.copy(outputColumns = outputSpec.outputColumns
      .map(FileSourceMetadataAttribute.cleanupFileSourceMetadataInformation))
    val dataColumns = finalOutputSpec.outputColumns.filterNot(partitionSet.contains)

    var needConvert = false
    val projectList: Seq[NamedExpression] = plan.output.map {
      case p if partitionSet.contains(p) && p.dataType == StringType && p.nullable =>
        needConvert = true
        Alias(Empty2Null(p), p.name)()
      case attr => attr
    }

    val empty2NullPlan = if (staticPartitionWriteOnly && nativeEnabled) {
      // The Velox backend only supports static partition writes, and no sort operator
      // is needed for a static partition write.
      plan
    } else {
      if (needConvert) ProjectExec(projectList, plan) else plan
    }

    val writerBucketSpec = bucketSpec.map {
      spec =>
        val bucketColumns = spec.bucketColumnNames.map(c => dataColumns.find(_.name == c).get)

        if (
          options.getOrElse(BucketingUtils.optionForHiveCompatibleBucketWrite, "false") ==
            "true"
        ) {
          // Hive bucketed table: use `HiveHash` and bitwise-and as the bucket id expression.
          // Without the extra bitwise-and operation, we can get a wrong bucket id when the hash
          // value of the columns is negative. See the Hive implementation in
          // `org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils#getBucketNumber()`.
          val hashId = BitwiseAnd(HiveHash(bucketColumns), Literal(Int.MaxValue))
          val bucketIdExpression = Pmod(hashId, Literal(spec.numBuckets))

          // The bucket file name prefix follows the Hive, Presto and Trino convention, so that
          // Hive bucketed tables written by Spark can be read by other SQL engines.
          //
          // Hive: `org.apache.hadoop.hive.ql.exec.Utilities#getBucketIdFromFile()`.
          // Trino: `io.trino.plugin.hive.BackgroundHiveSplitLoader#BUCKET_PATTERNS`.
          val fileNamePrefix = (bucketId: Int) => f"$bucketId%05d_0_"
          WriterBucketSpec(bucketIdExpression, fileNamePrefix)
        } else {
          // Spark bucketed table: use `HashPartitioning.partitionIdExpression` as the bucket id
          // expression, so that we can guarantee the data distribution is the same between the
          // shuffle and the bucketed data source, which enables us to only shuffle one side when
          // joining a bucketed table with a normal one.
          val bucketIdExpression =
            HashPartitioning(bucketColumns, spec.numBuckets).partitionIdExpression
          WriterBucketSpec(bucketIdExpression, (_: Int) => "")
        }
    }
    val sortColumns = bucketSpec.toSeq.flatMap {
      spec => spec.sortColumnNames.map(c => dataColumns.find(_.name == c).get)
    }

    val caseInsensitiveOptions = CaseInsensitiveMap(options)

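    // Only the non-partition (data) columns are written into the files; verify that the
    // file format actually supports that schema.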
    val dataSchema = dataColumns.toStructType
    DataSourceUtils.verifySchema(fileFormat, dataSchema)
    // Note: prepareWrite has a side effect: it mutates the "job" instance.
    val outputWriterFactory =
      fileFormat.prepareWrite(sparkSession, job, caseInsensitiveOptions, dataSchema)

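    // Capture everything the write tasks need in a serializable description that is
    // shipped to the executors.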
    val description = new WriteJobDescription(
      uuid = UUID.randomUUID.toString,
      serializableHadoopConf = new SerializableConfiguration(job.getConfiguration),
      outputWriterFactory = outputWriterFactory,
      allColumns = finalOutputSpec.outputColumns,
      dataColumns = dataColumns,
      partitionColumns = partitionColumns,
      bucketSpec = writerBucketSpec,
      path = finalOutputSpec.outputPath,
      customPartitionLocations = finalOutputSpec.customPartitionLocations,
      maxRecordsPerFile = caseInsensitiveOptions
        .get("maxRecordsPerFile")
        .map(_.toLong)
        .getOrElse(sparkSession.sessionState.conf.maxRecordsPerFile),
      timeZoneId = caseInsensitiveOptions
        .get(DateTimeUtils.TIMEZONE_OPTION)
        .getOrElse(sparkSession.sessionState.conf.sessionLocalTimeZone),
      statsTrackers = statsTrackers
    )

    // We should first sort by partition columns, then bucket id, and finally sorting columns.
    val requiredOrdering = partitionColumns.drop(numStaticPartitionCols) ++
      writerBucketSpec.map(_.bucketIdExpression) ++ sortColumns
    // the sort direction doesn't matter here; only the ordering expressions are compared
    val actualOrdering = empty2NullPlan.outputOrdering.map(_.child)
    val orderingMatched = if (requiredOrdering.length > actualOrdering.length) {
      false
    } else {
      requiredOrdering.zip(actualOrdering).forall {
        case (requiredOrder, childOutputOrder) =>
          requiredOrder.semanticEquals(childOutputOrder)
      }
    }

    SQLExecution.checkSQLExecutionId(sparkSession)

    // propagate the description UUID into the jobs, so that committers
    // get an ID guaranteed to be unique.
    job.getConfiguration.set("spark.sql.sources.writeJobUUID", description.uuid)

    // This call shouldn't be put into the `try` block below because it only initializes and
    // prepares the job; any exception thrown from here shouldn't cause abortJob() to be called.
    committer.setupJob(job)

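    // Hands the (possibly projected) plan to the Gluten native writer. When a bucket spec
    // is present, the bucket id is projected as an extra "__bucket_value__" column so the
    // backend can route each row to the right bucket file.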
    def nativeWrap(plan: SparkPlan) = {
      var wrapped: SparkPlan = plan
      if (writerBucketSpec.isDefined) {
        // We need to add the bucket id expression to the output of the sort plan,
        // so that the backend can calculate the bucket id for each row.
        wrapped = ProjectExec(
          wrapped.output :+ Alias(writerBucketSpec.get.bucketIdExpression, "__bucket_value__")(),
          wrapped)
        // TODO: optimize this; the bucket value is computed twice here.
      }

      val nativeFormat = sparkSession.sparkContext.getLocalProperty("nativeFormat")
      (GlutenFormatFactory(nativeFormat).executeWriterWrappedSparkPlan(wrapped), None)
    }

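    // Pick the execution path: reuse the child's ordering when it already satisfies the
    // partition/bucket/sort requirements, otherwise add a local sort; and choose between
    // the vanilla row-based write and the native write depending on the flags above.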
    try {
      val (rdd, concurrentOutputWriterSpec) = if (orderingMatched) {
        if (!nativeEnabled || (staticPartitionWriteOnly && nativeEnabled)) {
          (empty2NullPlan.execute(), None)
        } else {
          nativeWrap(empty2NullPlan)
        }
      } else {
        // SPARK-21165: the `requiredOrdering` is based on the attributes from analyzed plan, and
        // the physical plan may have different attribute ids due to optimizer removing some
        // aliases. Here we bind the expression ahead to avoid potential attribute ids mismatch.
        val orderingExpr = bindReferences(
          requiredOrdering.map(SortOrder(_, Ascending)),
          finalOutputSpec.outputColumns)
        val sortPlan = SortExec(orderingExpr, global = false, child = empty2NullPlan)

        val maxWriters = sparkSession.sessionState.conf.maxConcurrentOutputFileWriters
        var concurrentWritersEnabled = maxWriters > 0 && sortColumns.isEmpty
        if (nativeEnabled && concurrentWritersEnabled) {
          log.warn(
            s"spark.sql.maxConcurrentOutputFileWriters (set to $maxWriters) will be " +
              "ignored when the native writer is active; no concurrent writers will be used.")
          concurrentWritersEnabled = false
        }

        if (concurrentWritersEnabled) {
          (
            empty2NullPlan.execute(),
            Some(ConcurrentOutputWriterSpec(maxWriters, () => sortPlan.createSorter())))
        } else {
          if (staticPartitionWriteOnly && nativeEnabled) {
            // Skip the sort operator for static partition writes.
            (empty2NullPlan.execute(), None)
          } else {
            if (!nativeEnabled) {
              (sortPlan.execute(), None)
            } else {
              nativeWrap(sortPlan)
            }
          }
        }
      }

      // SPARK-23271 If we are attempting to write a zero partition rdd, create a dummy single
      // partition rdd to make sure we at least set up one write task to write the metadata.
      val rddWithNonEmptyPartitions = if (rdd.partitions.length == 0) {
        sparkSession.sparkContext.parallelize(Array.empty[InternalRow], 1)
      } else {
        rdd
      }

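      // Run one write task per RDD partition; each task reports a WriteTaskResult that is
      // collected on the driver and acknowledged per task via onTaskCommit.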
      val jobIdInstant = new Date().getTime
      val ret = new Array[WriteTaskResult](rddWithNonEmptyPartitions.partitions.length)
      sparkSession.sparkContext.runJob(
        rddWithNonEmptyPartitions,
        (taskContext: TaskContext, iter: Iterator[InternalRow]) => {
          executeTask(
            description = description,
            jobIdInstant = jobIdInstant,
            sparkStageId = taskContext.stageId(),
            sparkPartitionId = taskContext.partitionId(),
            sparkAttemptNumber = taskContext.taskAttemptId().toInt & Integer.MAX_VALUE,
            committer,
            iterator = iter,
            concurrentOutputWriterSpec = concurrentOutputWriterSpec
          )
        },
        rddWithNonEmptyPartitions.partitions.indices,
        (index, res: WriteTaskResult) => {
          committer.onTaskCommit(res.commitMsg)
          ret(index) = res
        }
      )

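      // Commit the whole job with the collected task commit messages, then feed the
      // per-task stats to the registered stats trackers.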
      val commitMsgs = ret.map(_.commitMsg)

      logInfo(s"Start to commit write Job ${description.uuid}.")
      val (_, duration) = Utils.timeTakenMs(committer.commitJob(job, commitMsgs))
      logInfo(s"Write Job ${description.uuid} committed. Elapsed time: $duration ms.")

      processStats(description.statsTrackers, ret.map(_.summary.stats), duration)
      logInfo(s"Finished processing stats for write job ${description.uuid}.")

      // return a set of all the partition paths that were updated during this job
      ret.map(_.summary.updatedPartitions).reduceOption(_ ++ _).getOrElse(Set.empty)
    } catch {
      case cause: Throwable =>
        logError(s"Aborting job ${description.uuid}.", cause)
        committer.abortJob(job)
        throw cause
    }
  }