in sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinEvaluatorFactory.scala [41:303]
/**
 * Builds the evaluator used to run this join.
 *
 * @return a newly constructed [[SortMergeJoinEvaluator]]; a separate instance is created
 *         on every call
 */
override def createEvaluator(): PartitionEvaluator[InternalRow, InternalRow] = {
  new SortMergeJoinEvaluator()
}
/**
 * Partition evaluator that runs the merge step of the join for each supported join type.
 *
 * The class has no state of its own: `leftKeys`, `rightKeys`, `joinType`, `condition`, `left`,
 * `right`, `output`, the threshold values, `spillSize`, `numOutputRows` and
 * `onlyBufferFirstMatchedRow` are all names resolved from the enclosing scope.
 */
private class SortMergeJoinEvaluator extends PartitionEvaluator[InternalRow, InternalRow] {

  /** Invokes `cleanupResources()` on the left child and then on the right child. */
  private def cleanupResources(): Unit = {
    left.cleanupResources()
    right.cleanupResources()
  }

  /** Builds a projection that evaluates `leftKeys` against the left child's output. */
  private def createLeftKeyGenerator(): Projection = {
    UnsafeProjection.create(leftKeys, left.output)
  }

  /** Builds a projection that evaluates `rightKeys` against the right child's output. */
  private def createRightKeyGenerator(): Projection = {
    UnsafeProjection.create(rightKeys, right.output)
  }

  /**
   * Joins the two input iterators according to `joinType`.
   *
   * @param partitionIndex index of the partition being evaluated; not read by this method
   * @param inputs exactly two iterators: the left side at position 0 and the right side at
   *               position 1 (enforced by an assertion)
   * @return an iterator over the joined rows
   * @throws IllegalArgumentException if `joinType` is not one of the cases handled below
   */
  override def eval(
      partitionIndex: Int,
      inputs: Iterator[InternalRow]*): Iterator[InternalRow] = {
    assert(inputs.length == 2)
    val leftInput = inputs(0)
    val rightInput = inputs(1)

    // Extra join predicate evaluated on the joined (left ++ right) row; when no condition is
    // defined every candidate pair is accepted.
    val boundCondition: InternalRow => Boolean = condition match {
      case Some(cond) =>
        Predicate.create(cond, left.output ++ right.output).eval _
      case None =>
        (_: InternalRow) => true
    }

    // Natural ascending ordering over the data types of the left keys; every scanner below
    // receives it for comparing join keys.
    val keyOrdering = RowOrdering.createNaturalAscendingOrdering(leftKeys.map(_.dataType))

    // Projection of `output` bound against `output`; applied to rows by the inner, outer and
    // existence paths before they are handed back.
    val resultProj: InternalRow => InternalRow = UnsafeProjection.create(output, output)

    joinType match {
      case _: InnerLike =>
        new RowIterator {
          private[this] var currentLeftRow: InternalRow = _
          private[this] var currentRightMatches: ExternalAppendOnlyUnsafeRowArray = _
          private[this] var rightMatchesIterator: Iterator[UnsafeRow] = null
          private[this] val scanner = new SortMergeJoinScanner(
            createLeftKeyGenerator(),
            createRightKeyGenerator(),
            keyOrdering,
            RowIterator.fromScala(leftInput),
            RowIterator.fromScala(rightInput),
            inMemoryThreshold,
            spillThreshold,
            spillSize,
            cleanupResources)
          private[this] val joinRow = new JoinedRow

          /**
           * Asks the scanner for its next inner-join result. On success the streamed row,
           * the buffered matches and a fresh iterator over them are stored; otherwise all
           * three fields are reset to null. Returns whether the scanner produced a result.
           */
          private[this] def loadNextMatches(): Boolean = {
            val hasMore = scanner.findNextInnerJoinRows()
            if (hasMore) {
              currentRightMatches = scanner.getBufferedMatches
              currentLeftRow = scanner.getStreamedRow
              rightMatchesIterator = currentRightMatches.generateIterator()
            } else {
              currentRightMatches = null
              currentLeftRow = null
              rightMatchesIterator = null
            }
            hasMore
          }

          // Prime the iterator state once at construction time. If nothing is found the
          // fields simply stay null, which makes advanceNext() report exhaustion.
          loadNextMatches()

          override def advanceNext(): Boolean = {
            var produced = false
            // A null iterator means the scanner has been exhausted, so the loop ends.
            while (!produced && rightMatchesIterator != null) {
              // Refill from the scanner only after the current matches have been consumed.
              if (rightMatchesIterator.hasNext || loadNextMatches()) {
                joinRow(currentLeftRow, rightMatchesIterator.next())
                if (boundCondition(joinRow)) {
                  numOutputRows += 1
                  produced = true
                }
              }
            }
            produced
          }

          override def getRow: InternalRow = resultProj(joinRow)
        }.toScala

      case LeftOuter =>
        // The left side is streamed and the right side is buffered.
        val scanner = new SortMergeJoinScanner(
          streamedKeyGenerator = createLeftKeyGenerator(),
          bufferedKeyGenerator = createRightKeyGenerator(),
          keyOrdering,
          streamedIter = RowIterator.fromScala(leftInput),
          bufferedIter = RowIterator.fromScala(rightInput),
          inMemoryThreshold,
          spillThreshold,
          spillSize,
          cleanupResources)
        new LeftOuterIterator(
          scanner,
          // Row with as many fields as the right child's output, passed as the null row.
          new GenericInternalRow(right.output.length),
          boundCondition,
          resultProj,
          numOutputRows).toScala

      case RightOuter =>
        // The right side is streamed and the left side is buffered.
        val scanner = new SortMergeJoinScanner(
          streamedKeyGenerator = createRightKeyGenerator(),
          bufferedKeyGenerator = createLeftKeyGenerator(),
          keyOrdering,
          streamedIter = RowIterator.fromScala(rightInput),
          bufferedIter = RowIterator.fromScala(leftInput),
          inMemoryThreshold,
          spillThreshold,
          spillSize,
          cleanupResources)
        new RightOuterIterator(
          scanner,
          // Row with as many fields as the left child's output, passed as the null row.
          new GenericInternalRow(left.output.length),
          boundCondition,
          resultProj,
          numOutputRows).toScala

      case FullOuter =>
        val nullLeft = new GenericInternalRow(left.output.length)
        val nullRight = new GenericInternalRow(right.output.length)
        // Unlike the other cases, the join condition is handed to the scanner itself.
        val scanner = new SortMergeFullOuterJoinScanner(
          leftKeyGenerator = createLeftKeyGenerator(),
          rightKeyGenerator = createRightKeyGenerator(),
          keyOrdering,
          leftIter = RowIterator.fromScala(leftInput),
          rightIter = RowIterator.fromScala(rightInput),
          boundCondition,
          nullLeft,
          nullRight)
        new FullOuterIterator(scanner, resultProj, numOutputRows).toScala

      case LeftSemi =>
        new RowIterator {
          private[this] var currentLeftRow: InternalRow = _
          private[this] val scanner = new SortMergeJoinScanner(
            createLeftKeyGenerator(),
            createRightKeyGenerator(),
            keyOrdering,
            RowIterator.fromScala(leftInput),
            RowIterator.fromScala(rightInput),
            inMemoryThreshold,
            spillThreshold,
            spillSize,
            cleanupResources,
            onlyBufferFirstMatchedRow)
          private[this] val joinRow = new JoinedRow

          override def advanceNext(): Boolean = {
            var emitted = false
            // `!emitted` is tested first so the scanner is not advanced again once a left
            // row has been accepted.
            while (!emitted && scanner.findNextInnerJoinRows()) {
              val buffered = scanner.getBufferedMatches
              currentLeftRow = scanner.getStreamedRow
              if (buffered != null && buffered.length > 0) {
                val candidates = buffered.generateIterator()
                // Stop at the first buffered row that satisfies the join condition.
                while (!emitted && candidates.hasNext) {
                  joinRow(currentLeftRow, candidates.next())
                  if (boundCondition(joinRow)) {
                    numOutputRows += 1
                    emitted = true
                  }
                }
              }
            }
            emitted
          }

          // Only the left row is returned; no projection is applied here.
          override def getRow: InternalRow = currentLeftRow
        }.toScala

      case LeftAnti =>
        new RowIterator {
          private[this] var currentLeftRow: InternalRow = _
          private[this] val scanner = new SortMergeJoinScanner(
            createLeftKeyGenerator(),
            createRightKeyGenerator(),
            keyOrdering,
            RowIterator.fromScala(leftInput),
            RowIterator.fromScala(rightInput),
            inMemoryThreshold,
            spillThreshold,
            spillSize,
            cleanupResources,
            onlyBufferFirstMatchedRow)
          private[this] val joinRow = new JoinedRow

          override def advanceNext(): Boolean = {
            var emitted = false
            while (!emitted && scanner.findNextOuterJoinRows()) {
              currentLeftRow = scanner.getStreamedRow
              val buffered = scanner.getBufferedMatches
              val nothingBuffered = buffered == null || buffered.length == 0
              var matched = false
              if (!nothingBuffered) {
                val candidates = buffered.generateIterator()
                while (!matched && candidates.hasNext) {
                  joinRow(currentLeftRow, candidates.next())
                  matched = boundCondition(joinRow)
                }
              }
              // A left row is emitted only when no buffered row satisfied the condition,
              // which includes the case where nothing was buffered at all.
              if (!matched) {
                numOutputRows += 1
                emitted = true
              }
            }
            emitted
          }

          // Only the left row is returned; no projection is applied here.
          override def getRow: InternalRow = currentLeftRow
        }.toScala

      case _: ExistenceJoin =>
        new RowIterator {
          private[this] var currentLeftRow: InternalRow = _
          // Single-field row holding the "a match exists" flag for the current left row.
          private[this] val result: InternalRow = new GenericInternalRow(Array[Any](null))
          private[this] val scanner = new SortMergeJoinScanner(
            createLeftKeyGenerator(),
            createRightKeyGenerator(),
            keyOrdering,
            RowIterator.fromScala(leftInput),
            RowIterator.fromScala(rightInput),
            inMemoryThreshold,
            spillThreshold,
            spillSize,
            cleanupResources,
            onlyBufferFirstMatchedRow)
          private[this] val joinRow = new JoinedRow

          override def advanceNext(): Boolean = {
            // Every row the scanner yields produces exactly one output row, so a single
            // scanner step per call is sufficient.
            if (scanner.findNextOuterJoinRows()) {
              currentLeftRow = scanner.getStreamedRow
              val buffered = scanner.getBufferedMatches
              var matched = false
              if (buffered != null && buffered.length > 0) {
                val candidates = buffered.generateIterator()
                while (!matched && candidates.hasNext) {
                  joinRow(currentLeftRow, candidates.next())
                  matched = boundCondition(joinRow)
                }
              }
              result.setBoolean(0, matched)
              numOutputRows += 1
              true
            } else {
              false
            }
          }

          // The left row is joined with the flag row and then projected.
          override def getRow: InternalRow = resultProj(joinRow(currentLeftRow, result))
        }.toScala

      case x =>
        throw new IllegalArgumentException(s"SortMergeJoin should not take $x as the JoinType")
    }
  }
}