in backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala [472:584]
override def createBroadcastRelation(
mode: BroadcastMode,
child: SparkPlan,
numOutputRows: SQLMetric,
dataSize: SQLMetric): BuildSideRelation = {
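// Extract the build-side join keys and the null-aware flag from the broadcast mode.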
val (buildKeys, isNullAware) = mode match {
case mode1: HashedRelationBroadcastMode =>
(mode1.key, mode1.isNullAware)
case _ =>
// IdentityBroadcastMode: no build keys, never null-aware
(Seq.empty, false)
}
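// Keys that are plain attribute or bound references can be read directly from the
// child's output; any other expression needs a pre-projection to compute it.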
val (newChild, newOutput, newBuildKeys) =
if (
buildKeys
.forall(k => k.isInstanceOf[AttributeReference] || k.isInstanceOf[BoundReference])
) {
(child, child.output, Seq.empty[Expression])
} else {
// Insert a pre-projection to materialize expression join keys as named columns
val appendedProjections = new ArrayBuffer[NamedExpression]()
val preProjectionBuildKeys = buildKeys.zipWithIndex.map {
case (e, idx) =>
e match {
case b: BoundReference => child.output(b.ordinal)
case o: Expression =>
val newExpr = Alias(o, "col_" + idx)()
appendedProjections += newExpr
newExpr
}
}
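// Wraps the child in a fresh whole-stage transform stage whose root is the
// pre-projection, for operators that cannot absorb the projection directly.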
def wrapChild(child: SparkPlan): WholeStageTransformer = {
val childWithAdapter = ColumnarCollapseTransformStages.wrapInputIteratorTransformer(child)
WholeStageTransformer(
ProjectExecTransformer(child.output ++ appendedProjections, childWithAdapter))(
ColumnarCollapseTransformStages.transformStageCounter.incrementAndGet()
)
}
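// Attach the pre-projection to the child; the strategy depends on the operator type.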
val newChild = child match {
case wt: WholeStageTransformer =>
wt.withNewChildren(
Seq(ProjectExecTransformer(child.output ++ appendedProjections, wt.child)))
case w: WholeStageCodegenExec =>
w.withNewChildren(Seq(ProjectExec(child.output ++ appendedProjections, w.child)))
case r: AQEShuffleReadExec if r.supportsColumnar =>
// When AQE is enabled, wrap the shuffle read in a new transform stage.
// TODO: remove this once the pre-projection can be pushed down.
wrapChild(r)
case r2c: RowToCHNativeColumnarExec =>
wrapChild(r2c)
case union: ColumnarUnionExec =>
wrapChild(union)
case ordered: TakeOrderedAndProjectExecTransformer =>
wrapChild(ordered)
case rddScan: CHRDDScanTransformer =>
wrapChild(rddScan)
case other =>
throw new GlutenNotSupportException(
s"Not supported operator ${other.nodeName} for BroadcastRelation")
}
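// Return the rewritten child, the output extended with the projected columns,
// and the rewritten build keys.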
(newChild, (child.output ++ appendedProjections).map(_.toAttribute), preProjectionBuildKeys)
}
// For null-aware joins, find the ordinal of the first build key in the output.
val keyColumnIndex = if (isNullAware) {
def findKeyOrdinal(key: Expression, output: Seq[Attribute]): Int = {
key match {
case b: BoundReference => b.ordinal
case n: NamedExpression =>
output.indexWhere(o => o.name == n.name && o.exprId == n.exprId)
case _ =>
throw new GlutenException(s"Cannot find key $key in the child's output: $output")
}
}
if (newBuildKeys.isEmpty) {
findKeyOrdinal(buildKeys.head, newOutput)
} else {
findKeyOrdinal(newBuildKeys.head, newOutput)
}
} else {
0
}
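// Evaluate the build side and collect per-partition (row count, serialized batches,
// has-null-key) triples to the driver.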
val countsAndBytes =
CHExecUtil.buildSideRDD(dataSize, newChild, isNullAware, keyColumnIndex).collect
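// Serialized batch bytes per partition and their combined size.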
val batches = countsAndBytes.map(_._2)
val totalBatchesSize = batches.map(_.length).sum
val rawSize = dataSize.value
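// Fail fast if the broadcast table exceeds the configured size limit.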
if (rawSize >= GlutenConfig.get.maxBroadcastTableSize) {
throw new GlutenException(
"Cannot broadcast the table that is larger than " +
s"${SparkMemoryUtil.bytesToString(GlutenConfig.get.maxBroadcastTableSize)}: " +
s"${SparkMemoryUtil.bytesToString(rawSize)}")
}
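// Sanity check: non-empty batches with a zero raw size, or a negative total,
// indicate corrupted shuffle write metrics.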
if ((rawSize == 0 && totalBatchesSize != 0) || totalBatchesSize < 0) {
throw new GlutenException(
s"Invalid rawSize ($rawSize) or totalBatchesSize ($totalBatchesSize). Ensure the" +
" shuffle written bytes are correct.")
}
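// Aggregate per-partition row counts and detect whether any partition saw a null key.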
val rowCount = countsAndBytes.map(_._1).sum
val hasNullKeyValues = countsAndBytes.exists(_._3)
numOutputRows += rowCount
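// Build the driver-side relation that will be broadcast to the executors.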
ClickHouseBuildSideRelation(
mode,
newOutput,
batches.flatten,
rowCount,
newBuildKeys,
hasNullKeyValues)
}