in client-spark/spark-3/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnarBatchSerializer.scala [98:172]
override def deserializeStream(in: InputStream): DeserializationStream = {
  val numFields = schema.fields.length
  new DeserializationStream {
    val dIn: DataInputStream = new DataInputStream(new BufferedInputStream(in))
    val EOF: Int = -1
    // Reusable scratch buffer for compressed column bytes; grown on demand.
    var colBuffer: Array[Byte] = new Array[Byte](1024)
    // Row count of the next batch, or EOF once the stream is exhausted.
    var numRows: Int = readSize()
    var rowIter: Iterator[InternalRow] = if (numRows != EOF) nextBatch() else Iterator.empty

    override def asKeyValueIterator: Iterator[(Int, InternalRow)] = {
      new Iterator[(Int, InternalRow)] {
        override def hasNext: Boolean = rowIter.hasNext || {
          // Current batch drained: decode the next one if the stream has more data.
          if (numRows != EOF) {
            rowIter = nextBatch()
            true
          } else {
            false
          }
        }

        override def next(): (Int, InternalRow) = {
          // The key is not used downstream, so a constant 0 is returned.
          (0, rowIter.next())
        }
      }
    }

    override def asIterator: Iterator[Any] = {
      throw new UnsupportedOperationException
    }

    override def readObject[T: ClassTag](): T = {
      throw new UnsupportedOperationException
    }

    def nextBatch(): Iterator[InternalRow] = {
      // Allocate column vectors on-heap or off-heap according to configuration.
      val columnVectors =
        if (!offHeapColumnVectorEnabled) {
          OnHeapColumnVector.allocateColumns(numRows, schema)
        } else {
          OffHeapColumnVector.allocateColumns(numRows, schema)
        }
      val columnarBatch = new ColumnarBatch(columnVectors.asInstanceOf[Array[ColumnVector]])
      columnarBatch.setNumRows(numRows)
      // Each column is framed as a length-prefixed block of compressed bytes.
      for (i <- 0 until numFields) {
        val colLen: Int = readSize()
        if (colBuffer.length < colLen) {
          colBuffer = new Array[Byte](colLen)
        }
        ByteStreams.readFully(dIn, colBuffer, 0, colLen)
        CelebornColumnAccessor.decompress(
          colBuffer,
          columnarBatch.column(i).asInstanceOf[WritableColumnVector],
          schema.fields(i).dataType,
          numRows)
      }
      // Read the row count of the following batch (EOF if none) before handing out rows.
      numRows = readSize()
      columnarBatch.rowIterator().asScala.map(toUnsafe)
    }

    // Reads a 4-byte size header, mapping end-of-stream to the EOF sentinel.
    def readSize(): Int =
      try {
        dIn.readInt()
      } catch {
        case e: EOFException =>
          dIn.close()
          EOF
      }

    override def close(): Unit = {
      dIn.close()
    }
  }
}
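
// A minimal write-side sketch of the frame layout the reader above expects,
// assuming each column has already been compressed into a byte array; the actual
// serialization logic of CelebornColumnarBatchSerializer is not shown in this
// excerpt, and writeFramedBatch below is a hypothetical helper for illustration only.
//
//   [numRows: Int][for each field: [colLen: Int][colLen compressed bytes]] ...
//
// Batches repeat back to back with no explicit terminator: readSize() maps the
// EOFException thrown by the final readInt() to the EOF sentinel.
import java.io.DataOutputStream

def writeFramedBatch(
    out: DataOutputStream,
    numRows: Int,
    compressedColumns: Seq[Array[Byte]]): Unit = {
  out.writeInt(numRows)             // batch header: row count
  compressedColumns.foreach { col =>
    out.writeInt(col.length)        // per-column header: compressed byte length
    out.write(col, 0, col.length)   // compressed column payload
  }
}

// On the read side, callers consume the stream through asKeyValueIterator; the key
// is always 0 and only the InternalRow values are meaningful.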