in gluten-core/src/main/scala/io/glutenproject/execution/JoinUtils.scala [176:289]
/**
 * Builds the Substrait relation for a hash-join-like operator: a join rel over the streamed and
 * build inputs, wrapped in a project rel that selects and orders the final output columns.
 *
 * Steps, in order:
 *   1. `createPreProjectionIfNeeded` is applied to the streamed side, then to the build side.
 *   2. Streamed/build keys are paired positionally, each pair becomes an equality expression,
 *      and the equalities are combined pairwise with AND. The combination uses `reduce`, so at
 *      least one key pair is expected.
 *   3. `condition`, when defined, is converted against `streamedOutput ++ buildOutput` and
 *      passed to the join rel as the post-join filter (`null` when absent).
 *   4. A project rel on top of the join picks the result columns by ordinal.
 *
 * @param streamedKeyExprs join key expressions of the streamed side
 * @param buildKeyExprs join key expressions of the build side
 * @param condition optional extra join condition, turned into the post-join filter
 * @param substraitJoinType join type handed to `RelBuilder.makeJoinRel`
 * @param exchangeTable when true, the build side is treated as the left side for the output
 * @param joinType Spark join type; drives which columns the result projection selects
 * @param joinParameters opaque value forwarded to `createJoinExtensionNode`
 * @param inputStreamedRelNode rel node of the streamed input
 * @param inputBuildRelNode rel node of the build input
 * @param inputStreamedOutput output attributes of the streamed input
 * @param inputBuildOutput output attributes of the build input
 * @param substraitContext context supplying the registered-function map
 * @param operatorId operator id forwarded to every rel-building call
 * @param validation forwarded to the pre-projection and extension-node helpers
 * @return the project rel wrapping the join rel
 */
def createJoinRel(
    streamedKeyExprs: Seq[Expression],
    buildKeyExprs: Seq[Expression],
    condition: Option[Expression],
    substraitJoinType: JoinRel.JoinType,
    exchangeTable: Boolean,
    joinType: JoinType,
    joinParameters: Any,
    inputStreamedRelNode: RelNode,
    inputBuildRelNode: RelNode,
    inputStreamedOutput: Seq[Attribute],
    inputBuildOutput: Seq[Attribute],
    substraitContext: SubstraitContext,
    operatorId: java.lang.Long,
    validation: Boolean = false): RelNode = {
  // scalastyle:on argcount

  // Builds a field reference by ordinal into the join's output row.
  def selectAt(ordinal: Int) = ExpressionBuilder.makeSelection(ordinal)

  // Pre-projection for the streamed plan first, then the build plan; projected keys are
  // appended to each side.
  val (streamedKeys, streamedRelNode, streamedOutput) = createPreProjectionIfNeeded(
    streamedKeyExprs,
    inputStreamedRelNode,
    inputStreamedOutput,
    inputStreamedOutput,
    substraitContext,
    operatorId,
    validation)
  // NOTE(review): the fourth argument here is the (possibly extended) streamed output followed
  // by the build input; presumably the attribute scope the build keys are bound against -
  // confirm against createPreProjectionIfNeeded.
  val (buildKeys, buildRelNode, buildOutput) = createPreProjectionIfNeeded(
    buildKeyExprs,
    inputBuildRelNode,
    inputBuildOutput,
    streamedOutput ++ inputBuildOutput,
    substraitContext,
    operatorId,
    validation)

  // Attributes of the raw join row: streamed columns followed by build columns.
  val joinedOutput = streamedOutput ++ buildOutput

  // One equality per key pair (all built first), then folded into a single AND-ed expression.
  val keyEqualities = streamedKeys.zip(buildKeys).map {
    case ((streamedKey, streamedKeyType), (buildKey, buildKeyType)) =>
      HashJoinLikeExecTransformer.makeEqualToExpression(
        streamedKey,
        streamedKeyType,
        buildKey,
        buildKeyType,
        substraitContext.registeredFunction)
  }
  val joinExpressionNode = keyEqualities.reduce {
    (combined, next) =>
      HashJoinLikeExecTransformer.makeAndExpression(
        combined,
        next,
        substraitContext.registeredFunction)
  }

  // Post-join filter, which will be computed in hash join.
  val postJoinFilter = condition.map {
    filterExpr =>
      ExpressionConverter
        .replaceWithExpressionTransformer(filterExpr, joinedOutput)
        .doTransform(substraitContext.registeredFunction)
  }

  // Assemble the JoinRel itself.
  val joinExtension = createJoinExtensionNode(joinParameters, joinedOutput)
  val joinRel = RelBuilder.makeJoinRel(
    streamedRelNode,
    buildRelNode,
    substraitJoinType,
    joinExpressionNode,
    postJoinFilter.orNull,
    joinExtension,
    substraitContext,
    operatorId
  )

  // The result projection drops the appended keys and swaps the column order if BuildLeft.
  val (leftInput, rightInput) =
    if (exchangeTable) (inputBuildOutput, inputStreamedOutput)
    else (inputStreamedOutput, inputBuildOutput)
  val (leftOutput, rightOutput) = getDirectJoinOutput(joinType, leftInput, rightInput)
  val streamedWidth = streamedOutput.size
  val resultProjection = joinType match {
    // Existence joins keep one side's input columns plus one extra ordinal
    // (presumably the "exists" flag column - TODO confirm).
    case _: ExistenceJoin if exchangeTable =>
      inputBuildOutput.indices.map(selectAt) :+ selectAt(buildOutput.size)
    case _: ExistenceJoin =>
      inputStreamedOutput.indices.map(selectAt) :+ selectAt(streamedWidth)
    case LeftExistence(_) if exchangeTable =>
      leftOutput.indices.map(selectAt)
    case _ if exchangeTable =>
      // Exchange the order of build and streamed.
      leftOutput.indices.map(idx => selectAt(idx + streamedWidth)) ++
        rightOutput.indices.map(selectAt)
    case _ =>
      leftOutput.indices.map(selectAt) ++
        rightOutput.indices.map(idx => selectAt(idx + streamedWidth))
  }

  // Attributes describing the join's direct output, with sides swapped when exchanged.
  val (firstSide, secondSide) =
    if (exchangeTable) (buildOutput, streamedOutput) else (streamedOutput, buildOutput)
  val directJoinOutputs = getDirectJoinOutputSeq(joinType, firstSide, secondSide)

  val projectedColumns = new java.util.ArrayList[ExpressionNode](resultProjection.asJava)
  val projectExtension = createExtensionNode(directJoinOutputs, validation)
  RelBuilder.makeProjectRel(
    joinRel,
    projectedColumns,
    projectExtension,
    substraitContext,
    operatorId,
    directJoinOutputs.size
  )
}