in spark/src/main/scala/org/apache/spark/sql/comet/CometColumnarToRowExec.scala [213:301]
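/**
 * Generates the whole-stage codegen "produce" path for columnar-to-row conversion: the emitted
 * Java code pulls ColumnarBatches from the input iterator, reads each row through per-column
 * accessors, and (Comet's fix for SPARK-50235) closes column vectors once a batch is drained.
 */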
override protected def doProduce(ctx: CodegenContext): String = {
  // PhysicalRDD always just has one input
  val input = ctx.addMutableState("scala.collection.Iterator", "input", v => s"$v = inputs[0];")

  // metrics
  val numOutputRows = metricTerm(ctx, "numOutputRows")
  val numInputBatches = metricTerm(ctx, "numInputBatches")

  val columnarBatchClz = classOf[ColumnarBatch].getName
  val batch = ctx.addMutableState(columnarBatchClz, "batch")

  val idx = ctx.addMutableState(CodeGenerator.JAVA_INT, "batchIdx") // init as batchIdx = 0
  val columnVectorClzs =
    child.vectorTypes.getOrElse(Seq.fill(output.indices.size)(classOf[ColumnVector].getName))
  val (colVars, columnAssigns) = columnVectorClzs.zipWithIndex.map {
    case (columnVectorClz, i) =>
      val name = ctx.addMutableState(columnVectorClz, s"colInstance$i")
      (name, s"$name = ($columnVectorClz) $batch.column($i);")
  }.unzip
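
  // Generated helper that advances to the next batch: it pulls one ColumnarBatch from the
  // input iterator, bumps the batch/row metrics, resets batchIdx, and rebinds the per-column
  // accessor fields to the new batch's vectors.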
  val nextBatch = ctx.freshName("nextBatch")
  val nextBatchFuncName = ctx.addNewFunction(
    nextBatch,
    s"""
       |private void $nextBatch() throws java.io.IOException {
       |  if ($input.hasNext()) {
       |    $batch = ($columnarBatchClz)$input.next();
       |    $numInputBatches.add(1);
       |    $numOutputRows.add($batch.numRows());
       |    $idx = 0;
       |    ${columnAssigns.mkString("", "\n", "\n")}
       |  }
       |}""".stripMargin)
  ctx.currentVars = null
  val rowidx = ctx.freshName("rowIdx")
  val columnsBatchInput = (output zip colVars).map { case (attr, colVar) =>
    genCodeColumnVector(ctx, colVar, rowidx, attr.dataType, attr.nullable)
  }
  val localIdx = ctx.freshName("localIdx")
  val localEnd = ctx.freshName("localEnd")
  val numRows = ctx.freshName("numRows")
  val shouldStop = if (parent.needStopCheck) {
    s"if (shouldStop()) { $idx = $rowidx + 1; return; }"
  } else {
    "// shouldStop check is eliminated"
  }
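  // When the parent asks to stop, the generated check persists the read position
  // (current row + 1) into batchIdx so the produce loop resumes from the correct row
  // on re-entry.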

  val writableColumnVectorClz = classOf[WritableColumnVector].getName
  val constantColumnVectorClz = classOf[ConstantColumnVector].getName
  val cometPlainColumnVectorClz = classOf[CometPlainVector].getName
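
  // The generated loop below drains each batch row by row, then eagerly releases the batch's
  // column vectors (the Comet fix for SPARK-50235) before fetching the next batch.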
  // scalastyle:off line.size.limit
  s"""
     |if ($batch == null) {
     |  $nextBatchFuncName();
     |}
     |while ($limitNotReachedCond $batch != null) {
     |  int $numRows = $batch.numRows();
     |  int $localEnd = $numRows - $idx;
     |  for (int $localIdx = 0; $localIdx < $localEnd; $localIdx++) {
     |    int $rowidx = $idx + $localIdx;
     |    ${consume(ctx, columnsBatchInput).trim}
     |    $shouldStop
     |  }
     |  $idx = $numRows;
     |
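     |  // Eagerly release this batch's column memory: vectors other than writable/constant
     |  // Spark vectors and Comet vectors are closed outright, while a CometPlainVector is
     |  // closed only if its underlying buffer is not marked as reused.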
     |  // Comet fix for SPARK-50235
     |  for (int i = 0; i < ${colVars.length}; i++) {
     |    if (!($batch.column(i) instanceof $writableColumnVectorClz || $batch.column(i) instanceof $constantColumnVectorClz || $batch.column(i) instanceof $cometPlainColumnVectorClz)) {
     |      $batch.column(i).close();
     |    } else if ($batch.column(i) instanceof $cometPlainColumnVectorClz) {
     |      $cometPlainColumnVectorClz cometPlainColumnVector = ($cometPlainColumnVectorClz) $batch.column(i);
     |      if (!cometPlainColumnVector.isReused()) {
     |        cometPlainColumnVector.close();
     |      }
     |    }
     |  }
     |
     |  $batch = null;
     |  $nextBatchFuncName();
     |}
     |// Comet fix for SPARK-50235: clean up resources
     |if ($batch != null) {
     |  $batch.close();
     |}
   """.stripMargin
  // scalastyle:on line.size.limit
}