def toColumnarBatchIterator()

in backends-velox/src/main/scala/io/glutenproject/execution/RowToVeloxColumnarExec.scala [81:213]


  /**
   * Converts an iterator of Spark rows into an iterator of columnar batches by calling the native
   * row-to-columnar converter.
   *
   * For every output batch, up to `columnBatchSize` rows are copied back to back, in their
   * `UnsafeRow` binary form, into a single Arrow buffer. The buffer address and the per-row byte
   * lengths are then passed to `nativeConvertRowToColumnar`.
   *
   * @param it
   *   input rows; rows that are not already `UnsafeRow` are projected to `UnsafeRow` first
   * @param schema
   *   schema of the input rows; exported to the native converter once, when this method is called
   * @param numInputRows
   *   metric incremented by the number of rows packed into each batch
   * @param numOutputBatches
   *   metric incremented by one for each produced batch
   * @param convertTime
   *   metric accumulating the milliseconds spent building each batch, measured after the first
   *   row of the batch has been pulled from `it`
   * @param columnBatchSize
   *   maximum number of rows per output batch
   * @return
   *   an iterator of batches, wrapped so that `jniWrapper.close(r2cHandle)` is registered as the
   *   iterator recycler and `close()` as the payload recycler; `Iterator.empty` when `it` has no
   *   rows
   */
  def toColumnarBatchIterator(
      it: Iterator[InternalRow],
      schema: StructType,
      numInputRows: SQLMetric,
      numOutputBatches: SQLMetric,
      convertTime: SQLMetric,
      columnBatchSize: Int): Iterator[ColumnarBatch] = {
    // Nothing to convert: skip creating any native state.
    if (it.isEmpty) {
      return Iterator.empty
    }

    val arrowSchema =
      SparkArrowUtil.toArrowSchema(schema, SQLConf.get.sessionLocalTimeZone)
    val jniWrapper = NativeRowToColumnarJniWrapper.create()
    val allocator = ArrowBufferAllocators.contextInstance()
    val cSchema = ArrowSchema.allocateNew(allocator)
    // The exported schema is only needed while the native converter is initialised, so it is
    // closed as soon as `init` returns (or throws).
    val r2cHandle =
      try {
        ArrowAbiUtil.exportSchema(allocator, arrowSchema, cSchema)
        jniWrapper.init(
          cSchema.memoryAddress(),
          NativeMemoryManagers
            .contextInstance("RowToColumnar")
            .getNativeInstanceHandle)
      } finally {
        cSchema.close()
      }

    val res: Iterator[ColumnarBatch] = new Iterator[ColumnarBatch] {
      // Set once the input iterator has been observed to be exhausted while filling a batch.
      var finished = false

      // Created at most once, and only if a non-UnsafeRow input is seen. The projection hands
      // back the same mutable UnsafeRow on every call; that is safe here because each converted
      // row is copied into the Arrow buffer before the next row is converted.
      private lazy val unsafeProjection = UnsafeProjection.create(schema)

      override def hasNext: Boolean = !finished && it.hasNext

      /**
       * Builds one batch: `row` is the already-fetched first row, the remaining rows (up to
       * `columnBatchSize` in total) are pulled from `it`.
       */
      def nativeConvert(row: UnsafeRow): ColumnarBatch = {
        var arrowBuf: ArrowBuf = null
        // Fallback release of the buffer in case the `finally` below is never reached.
        // NOTE(review): one recycler is registered per batch; presumably TaskResources runs
        // them when the task finishes - confirm this does not accumulate on long tasks.
        TaskResources.addRecycler("RowToColumnar_arrowBuf", 100) {
          // Remind, remove isOpen here
          if (arrowBuf != null && arrowBuf.refCnt() != 0) {
            arrowBuf.close()
          }
        }
        val rowLength = new ListBuffer[Long]()
        var rowCount = 0
        // Tracked as Long so that neither the running offset nor the growth computation below
        // can wrap around the Int range on very large batches.
        var offset = 0L

        // Copies one row to the end of the buffer, growing the buffer first if needed.
        def append(unsafeRow: UnsafeRow): Unit = {
          val sizeInBytes = unsafeRow.getSizeInBytes
          val required = offset + sizeInBytes
          if (required > arrowBuf.capacity()) {
            // Grow to twice the required size and carry over what has been written so far.
            val tmpBuf = allocator.buffer(required * 2)
            tmpBuf.setBytes(0, arrowBuf, 0, offset)
            arrowBuf.close()
            arrowBuf = tmpBuf
          }
          Platform.copyMemory(
            unsafeRow.getBaseObject,
            unsafeRow.getBaseOffset,
            null,
            arrowBuf.memoryAddress() + offset,
            sizeInBytes)
          offset = required
          rowLength += sizeInBytes.toLong
          rowCount += 1
        }

        val firstRowSize = row.getSizeInBytes
        // allocate buffer based on 1st row, but if first row is very big, this will cause OOM
        // maybe we should optimize to list ArrayBuf to native to avoid buf close and allocate
        // 31760L origins from BaseVariableWidthVector.lastValueAllocationSizeInBytes
        // experimental value
        val estimatedBufSize = Math.max(
          Math.min(firstRowSize.toDouble * columnBatchSize * 1.2, 31760L * columnBatchSize),
          firstRowSize.toDouble * 10)
        arrowBuf = allocator.buffer(estimatedBufSize.toLong)
        append(row)

        while (rowCount < columnBatchSize && !finished) {
          if (it.hasNext) {
            append(convertToUnsafeRow(it.next()))
          } else {
            finished = true
          }
        }
        numInputRows += rowCount
        try {
          val handle = jniWrapper
            .nativeConvertRowToColumnar(r2cHandle, rowLength.toArray, arrowBuf.memoryAddress())
          ColumnarBatches.create(Runtimes.contextInstance(), handle)
        } finally {
          // The row bytes are released once the native call has returned or thrown; clearing
          // the reference turns the recycler registered above into a no-op.
          arrowBuf.close()
          arrowBuf = null
        }
      }

      /** Returns `row` itself if it is already an `UnsafeRow`, otherwise its projection. */
      def convertToUnsafeRow(row: InternalRow): UnsafeRow = {
        row match {
          case unsafeRow: UnsafeRow => unsafeRow
          case other => unsafeProjection.apply(other)
        }
      }

      override def next(): ColumnarBatch = {
        val firstRow = it.next()
        val start = System.currentTimeMillis()
        val cb = nativeConvert(convertToUnsafeRow(firstRow))
        numOutputBatches += 1
        convertTime += System.currentTimeMillis() - start
        cb
      }
    }
    Iterators
      .wrap(res)
      .recycleIterator {
        jniWrapper.close(r2cHandle)
      }
      .recyclePayload(_.close())
      .create()
  }