in backends-clickhouse/src-delta-32/main/scala/org/apache/spark/sql/execution/datasources/v1/clickhouse/MergeTreeFileFormatWriter.scala [53:299]
def write(
sparkSession: SparkSession,
plan: SparkPlan,
fileFormat: FileFormat,
committer: FileCommitProtocol,
outputSpec: OutputSpec,
hadoopConf: Configuration,
partitionColumns: Seq[Attribute],
bucketSpec: Option[BucketSpec],
statsTrackers: Seq[WriteJobStatsTracker],
options: Map[String, String],
constraints: Seq[Constraint],
numStaticPartitionCols: Int = 0): Set[String] = {
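// Gluten-specific flags passed through Spark local properties: whether the native
// writer is applicable to this plan, and whether the backend only supports static
// partition writes.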
val nativeEnabled =
"true" == sparkSession.sparkContext.getLocalProperty("isNativeApplicable")
val staticPartitionWriteOnly =
"true" == sparkSession.sparkContext.getLocalProperty("staticPartitionWriteOnly")
if (nativeEnabled) {
logInfo("Use Gluten partition write for hive")
assert(plan.isInstanceOf[IFakeRowAdaptor])
}
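// Set up the Hadoop MapReduce job that the file commit protocol operates on.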
val job = Job.getInstance(hadoopConf)
job.setOutputKeyClass(classOf[Void])
job.setOutputValueClass(classOf[InternalRow])
FileOutputFormat.setOutputPath(job, new Path(outputSpec.outputPath))
val partitionSet = AttributeSet(partitionColumns)
// Clean up the internal metadata information of the file source
// metadata attribute, if any, before writing out.
val finalOutputSpec = outputSpec.copy(outputColumns = outputSpec.outputColumns
.map(FileSourceMetadataAttribute.cleanupFileSourceMetadataInformation))
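// Partition columns are encoded in the output directory layout, so only the
// remaining data columns are written into each file.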
val dataColumns = finalOutputSpec.outputColumns.filterNot(partitionSet.contains)
var needConvert = false
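// Rewrite empty strings in nullable string partition columns to null (via Empty2Null)
// so that empty partition values behave like null, i.e. the default partition.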
val projectList: Seq[NamedExpression] = plan.output.map {
case p if partitionSet.contains(p) && p.dataType == StringType && p.nullable =>
needConvert = true
Alias(Empty2Null(p), p.name)()
case attr => attr
}
val empty2NullPlan = if (staticPartitionWriteOnly && nativeEnabled) {
// The Velox backend only supports static partition writes,
// and no sort operator is needed for a static partition write.
plan
} else {
if (needConvert) ProjectExec(projectList, plan) else plan
}
val writerBucketSpec = bucketSpec.map {
spec =>
val bucketColumns = spec.bucketColumnNames.map(c => dataColumns.find(_.name == c).get)
if (
options.getOrElse(BucketingUtils.optionForHiveCompatibleBucketWrite, "false") ==
"true"
) {
// Hive bucketed table: use `HiveHash` and bitwise-and as bucket id expression.
// Without the extra bitwise-and operation, we can get a wrong bucket id when the hash
// value of the columns is negative. See the Hive implementation in
// `org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils#getBucketNumber()`.
val hashId = BitwiseAnd(HiveHash(bucketColumns), Literal(Int.MaxValue))
val bucketIdExpression = Pmod(hashId, Literal(spec.numBuckets))
// The bucket file name prefix follows the Hive, Presto and Trino convention, so this
// makes sure a Hive bucketed table written by Spark can be read by other SQL engines.
//
// Hive: `org.apache.hadoop.hive.ql.exec.Utilities#getBucketIdFromFile()`.
// Trino: `io.trino.plugin.hive.BackgroundHiveSplitLoader#BUCKET_PATTERNS`.
val fileNamePrefix = (bucketId: Int) => f"$bucketId%05d_0_"
WriterBucketSpec(bucketIdExpression, fileNamePrefix)
} else {
// Spark bucketed table: use `HashPartitioning.partitionIdExpression` as the bucket id
// expression, so that we can guarantee the data distribution is the same between shuffle
// and the bucketed data source, which enables us to only shuffle one side when joining a
// bucketed table with a normal one.
val bucketIdExpression =
HashPartitioning(bucketColumns, spec.numBuckets).partitionIdExpression
WriterBucketSpec(bucketIdExpression, (_: Int) => "")
}
}
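// For bucketed tables with sort columns, rows within each bucket must also be
// sorted by these columns.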
val sortColumns = bucketSpec.toSeq.flatMap {
spec => spec.sortColumnNames.map(c => dataColumns.find(_.name == c).get)
}
val caseInsensitiveOptions = CaseInsensitiveMap(options)
val dataSchema = dataColumns.toStructType
DataSourceUtils.verifySchema(fileFormat, dataSchema)
// Note: prepareWrite has a side effect: it configures the Hadoop "job".
val outputWriterFactory =
fileFormat.prepareWrite(sparkSession, job, caseInsensitiveOptions, dataSchema)
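// Bundle everything a write task needs into a serializable job description that is
// shipped to the executors.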
val description = new WriteJobDescription(
uuid = UUID.randomUUID.toString,
serializableHadoopConf = new SerializableConfiguration(job.getConfiguration),
outputWriterFactory = outputWriterFactory,
allColumns = finalOutputSpec.outputColumns,
dataColumns = dataColumns,
partitionColumns = partitionColumns,
bucketSpec = writerBucketSpec,
path = finalOutputSpec.outputPath,
customPartitionLocations = finalOutputSpec.customPartitionLocations,
maxRecordsPerFile = caseInsensitiveOptions
.get("maxRecordsPerFile")
.map(_.toLong)
.getOrElse(sparkSession.sessionState.conf.maxRecordsPerFile),
timeZoneId = caseInsensitiveOptions
.get(DateTimeUtils.TIMEZONE_OPTION)
.getOrElse(sparkSession.sessionState.conf.sessionLocalTimeZone),
statsTrackers = statsTrackers
)
// We should first sort by partition columns, then bucket id, and finally sorting columns.
val requiredOrdering = partitionColumns.drop(numStaticPartitionCols) ++
writerBucketSpec.map(_.bucketIdExpression) ++ sortColumns
// The sort direction doesn't matter here; only the ordering expressions are compared.
val actualOrdering = empty2NullPlan.outputOrdering.map(_.child)
val orderingMatched = if (requiredOrdering.length > actualOrdering.length) {
false
} else {
requiredOrdering.zip(actualOrdering).forall {
case (requiredOrder, childOutputOrder) =>
requiredOrder.semanticEquals(childOutputOrder)
}
}
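// Sanity-check that this write is running under a SQL execution id.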
SQLExecution.checkSQLExecutionId(sparkSession)
// propagate the description UUID into the jobs, so that committers
// get an ID guaranteed to be unique.
job.getConfiguration.set("spark.sql.sources.writeJobUUID", description.uuid)
// This call shouldn't be put into the `try` block below because it only initializes and
// prepares the job; any exception thrown from here shouldn't cause abortJob() to be called.
committer.setupJob(job)
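// Wraps the plan with the backend's native writer; the second element of the returned
// pair is the (unused) concurrent output writer spec.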
def nativeWrap(plan: SparkPlan) = {
var wrapped: SparkPlan = plan
if (writerBucketSpec.isDefined) {
// We need to add the bucket id expression to the plan's output,
// so that the native backend can compute the bucket id for each row.
wrapped = ProjectExec(
wrapped.output :+ Alias(writerBucketSpec.get.bucketIdExpression, "__bucket_value__")(),
wrapped)
// TODO: optimize; the bucket value is currently computed twice here.
}
val nativeFormat = sparkSession.sparkContext.getLocalProperty("nativeFormat")
(GlutenFormatFactory(nativeFormat).executeWriterWrappedSparkPlan(wrapped), None)
}
try {
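// Decide how to produce the input RDD: reuse the child's ordering when it already
// satisfies the required ordering, otherwise insert a local sort, and hand the plan
// to the native writer where applicable.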
val (rdd, concurrentOutputWriterSpec) = if (orderingMatched) {
if (!nativeEnabled || (staticPartitionWriteOnly && nativeEnabled)) {
(empty2NullPlan.execute(), None)
} else {
nativeWrap(empty2NullPlan)
}
} else {
// SPARK-21165: the `requiredOrdering` is based on the attributes from analyzed plan, and
// the physical plan may have different attribute ids due to optimizer removing some
// aliases. Here we bind the expression ahead to avoid potential attribute ids mismatch.
val orderingExpr = bindReferences(
requiredOrdering.map(SortOrder(_, Ascending)),
finalOutputSpec.outputColumns)
val sortPlan = SortExec(orderingExpr, global = false, child = empty2NullPlan)
val maxWriters = sparkSession.sessionState.conf.maxConcurrentOutputFileWriters
var concurrentWritersEnabled = maxWriters > 0 && sortColumns.isEmpty
if (nativeEnabled && concurrentWritersEnabled) {
log.warn(
s"spark.sql.maxConcurrentOutputFileWriters(being set to $maxWriters) will be " +
"ignored when native writer is being active. No concurrent Writers.")
concurrentWritersEnabled = false
}
if (concurrentWritersEnabled) {
(
empty2NullPlan.execute(),
Some(ConcurrentOutputWriterSpec(maxWriters, () => sortPlan.createSorter())))
} else {
if (staticPartitionWriteOnly && nativeEnabled) {
// Skip the sort operator for static partition write.
(empty2NullPlan.execute(), None)
} else {
if (!nativeEnabled) {
(sortPlan.execute(), None)
} else {
nativeWrap(sortPlan)
}
}
}
}
// SPARK-23271 If we are attempting to write a zero partition rdd, create a dummy single
// partition rdd to make sure we at least set up one write task to write the metadata.
val rddWithNonEmptyPartitions = if (rdd.partitions.length == 0) {
sparkSession.sparkContext.parallelize(Array.empty[InternalRow], 1)
} else {
rdd
}
val jobIdInstant = new Date().getTime
val ret = new Array[WriteTaskResult](rddWithNonEmptyPartitions.partitions.length)
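// Launch one write task per RDD partition; each task returns a WriteTaskResult whose
// commit message is reported back to the committer on the driver.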
sparkSession.sparkContext.runJob(
rddWithNonEmptyPartitions,
(taskContext: TaskContext, iter: Iterator[InternalRow]) => {
executeTask(
description = description,
jobIdInstant = jobIdInstant,
sparkStageId = taskContext.stageId(),
sparkPartitionId = taskContext.partitionId(),
sparkAttemptNumber = taskContext.taskAttemptId().toInt & Integer.MAX_VALUE,
committer,
iterator = iter,
concurrentOutputWriterSpec = concurrentOutputWriterSpec
)
},
rddWithNonEmptyPartitions.partitions.indices,
(index, res: WriteTaskResult) => {
committer.onTaskCommit(res.commitMsg)
ret(index) = res
}
)
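// All tasks have completed successfully: commit the job and let the stats trackers
// process the per-task write statistics.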
val commitMsgs = ret.map(_.commitMsg)
logInfo(s"Start to commit write Job ${description.uuid}.")
val (_, duration) = Utils.timeTakenMs(committer.commitJob(job, commitMsgs))
logInfo(s"Write Job ${description.uuid} committed. Elapsed time: $duration ms.")
processStats(description.statsTrackers, ret.map(_.summary.stats), duration)
logInfo(s"Finished processing stats for write job ${description.uuid}.")
// return a set of all the partition paths that were updated during this job
ret.map(_.summary.updatedPartitions).reduceOption(_ ++ _).getOrElse(Set.empty)
} catch {
case cause: Throwable =>
logError(s"Aborting job ${description.uuid}.", cause)
committer.abortJob(job)
throw cause
}
}