in gluten-data/src/main/scala/org/apache/spark/sql/execution/ColumnarBuildSideRelation.scala [89:219]
override def asReadOnlyCopy(
    broadCastContext: BroadCastHashJoinContext): ColumnarBuildSideRelation = {
  // The relation is treated as immutable, so this instance itself serves as
  // the read-only copy; no per-context state is derived from the argument.
  this
}
/**
* Transform columnar broadcast value to Array[InternalRow] by key and distinct. NOTE: This method
* is called in the Spark driver, so native resources must be managed carefully.
*/
override def transform(key: Expression): Array[InternalRow] = TaskResources.runUnsafe {
// This transformation happens in Spark driver, thus resources can not be managed automatically.
val runtime = Runtimes.contextInstance()
val nativeMemoryManager = NativeMemoryManagers.contextInstance("BuildSideRelation#transform")
val serializerJniWrapper = ColumnarBatchSerializerJniWrapper.create()
// Native deserializer handle, initialized from this relation's output schema exported
// through the Arrow C data interface; the temporary ArrowSchema is closed right after init.
val serializeHandle = {
val allocator = ArrowBufferAllocators.contextInstance()
val cSchema = ArrowSchema.allocateNew(allocator)
val arrowSchema = SparkArrowUtil.toArrowSchema(
StructType.fromAttributes(output),
SQLConf.get.sessionLocalTimeZone)
ArrowAbiUtil.exportSchema(allocator, arrowSchema, cSchema)
val handle = serializerJniWrapper
.init(cSchema.memoryAddress(), nativeMemoryManager.getNativeInstanceHandle)
cSchema.close()
handle
}
// Guards the one-shot release of the native handles performed in hasNext() below.
var closed = false
// Convert columnar to Row.
val jniWrapper = NativeColumnarToRowJniWrapper.create()
val c2rId = jniWrapper.nativeColumnarToRowInit(nativeMemoryManager.getNativeInstanceHandle)
var batchId = 0
// NOTE(review): in the empty-`batches` branch `c2rId` and `serializeHandle` are never
// explicitly closed here (the close calls live inside hasNext of the non-empty branch);
// presumably TaskResources.runUnsafe reclaims them on exit — confirm.
val iterator = if (batches.length > 0) {
val res: Iterator[Iterator[InternalRow]] = new Iterator[Iterator[InternalRow]] {
// Side-effecting hasNext: once the last batch has been consumed, release the
// native columnar-to-row converter and the deserializer exactly once.
override def hasNext: Boolean = {
val itHasNext = batchId < batches.length
if (!itHasNext && !closed) {
jniWrapper.nativeClose(c2rId)
serializerJniWrapper.close(serializeHandle)
closed = true
}
itHasNext
}
// Deserializes one serialized batch and yields its rows projected by `key`.
override def next(): Iterator[InternalRow] = {
val batchBytes = batches(batchId)
batchId += 1
val batchHandle =
serializerJniWrapper.deserialize(serializeHandle, batchBytes)
val batch = ColumnarBatches.create(runtime, batchHandle)
if (batch.numRows == 0) {
batch.close()
Iterator.empty
} else if (output.isEmpty) {
// No columns to project: emit the correct number of empty rows.
val rows = ColumnarBatches.emptyRowIterator(batch.numRows()).asScala
batch.close()
rows
} else {
val cols = batch.numCols()
val rows = batch.numRows()
// Native batch-to-row conversion; `info` carries the base memory address plus
// per-row offsets and lengths into the converted row buffer.
val info =
jniWrapper.nativeColumnarToRowConvert(batchHandle, c2rId)
batch.close()
// Collect every AttributeReference occurring in the key expression tree;
// exactly one is required (checked below).
val columnNames = key.flatMap {
case expression: AttributeReference =>
Some(expression)
case _ =>
None
}
if (columnNames.isEmpty) {
throw new IllegalArgumentException(s"Key column not found in expression: $key")
}
if (columnNames.size != 1) {
throw new IllegalArgumentException(s"Multiple key columns found in expression: $key")
}
val columnExpr = columnNames.head
// Resolve the key attribute to its position in `output`; exprId comparison is
// only needed when the output contains duplicate column names.
val oneColumnWithSameName = output.count(_.name == columnExpr.name) == 1
val columnInOutput = output.zipWithIndex.filter {
p: (Attribute, Int) =>
if (oneColumnWithSameName) {
// The comparison of exprId can be ignored when
// only one attribute name match is found.
p._1.name == columnExpr.name
} else {
// A case where output has multiple columns with same name
p._1.name == columnExpr.name && p._1.exprId == columnExpr.exprId
}
}
if (columnInOutput.isEmpty) {
throw new IllegalStateException(
s"Key $key not found from build side relation output: $output")
}
if (columnInOutput.size != 1) {
throw new IllegalStateException(
s"More than one key $key found from build side relation output: $output")
}
// Rebind the (single) attribute in the key expression to its row ordinal, then
// build an UnsafeProjection that evaluates the key against each produced row.
val replacement =
BoundReference(columnInOutput.head._2, columnExpr.dataType, columnExpr.nullable)
val projExpr = key.transformDown {
case _: AttributeReference =>
replacement
}
val proj = UnsafeProjection.create(projExpr)
// Walk the converted rows by pointing a single reusable UnsafeRow at each
// (offset, length) slice of the native buffer. The trailing `.copy()` is
// required: the row object is reused per iteration, and the backing native
// memory is presumably freed when the converter closes — rows must not
// alias it after this method returns.
new Iterator[InternalRow] {
var rowId = 0
val row = new UnsafeRow(cols)
override def hasNext: Boolean = {
rowId < rows
}
override def next: UnsafeRow = {
if (rowId >= rows) throw new NoSuchElementException
val (offset, length) = (info.offsets(rowId), info.lengths(rowId))
row.pointTo(null, info.memoryAddress + offset, length.toInt)
rowId += 1
row
}
}.map(proj).map(_.copy())
}
}
}
res.flatten
} else {
Iterator.empty
}
iterator.toArray
}