in spark-connector/common/src/main/scala/org/apache/spark/sql/odps/OdpsPartitionReaderFactory.scala [196:352]
override def createColumnarReader(partition: InputPartition): PartitionReader[ColumnarBatch] = {
  val conf = broadcastedConf.value.value
  // TODO: bearer token refresh
  val settings = OdpsClient.builder.config(conf).getOrCreate.getEnvironmentSettings
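  // Batch reuse is disabled for async reads: queued VectorSchemaRoots are consumed
  // later on the task thread, so the reader must not recycle them.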
  val reusedBatch = if (asyncRead) false else reusedBatchEnable
  val odpsScanPartition = partition.asInstanceOf[OdpsScanPartition]
  val readerOptions = ReaderOptions.newBuilder()
    .withMaxBatchRowCount(batchSize)
    .withSettings(settings)
    .withCompressionCodec(codec)
    .withReuseBatch(reusedBatch)
    .build()
  var arrowReader = odpsScanPartition.scan
    .createArrowReader(odpsScanPartition.inputSplit, readerOptions)
  val schema = odpsScanPartition.scan.readSchema
  var inputBytes = 0L
  var executor: ExecutorService = null
  val asyncQueueForVisit = new DataQueue(asyncReadQueueSize, asyncReadWaitTime)
  val DONE_SENTINEL = new Object
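
  // Async mode: a single background thread drains the arrow reader into a bounded
  // queue. End-of-data is signalled with a sentinel, and failures are propagated to
  // the consumer by queueing the Throwable itself.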
  if (asyncRead) {
    executor = Executors.newSingleThreadExecutor
    executor.submit(new Runnable() {
      override def run(): Unit = {
        try {
          while (arrowReader.hasNext) {
            asyncQueueForVisit.put(arrowReader.get)
          }
          asyncQueueForVisit.put(DONE_SENTINEL)
          arrowReader.currentMetricsValues.counter(MetricNames.BYTES_COUNT).ifPresent(c =>
            inputBytes = c.getCount)
          arrowReader.close()
        } catch {
          case cause: Throwable =>
            asyncQueueForVisit.put(cause)
        }
      }
    })
  }
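
  // The returned reader pulls batches either from the async queue or directly from
  // the arrow reader and exposes each VectorSchemaRoot as a Spark ColumnarBatch.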
  new PartitionReader[ColumnarBatch] {
    private var columnarBatch: ColumnarBatch = _
    private var loadData = false
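
    // Wraps the field vectors of the given root into a ColumnarBatch, mapping each
    // requested column name to its position in the arrow schema.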
    private def updateColumnBatch(root: VectorSchemaRoot): Unit = {
      if (columnarBatch != null && !reusedBatch) {
        columnarBatch.close()
      }
      val vectors = root.getFieldVectors
      val fields = root.getSchema.getFields
      val fieldNameIdxMap = fields.asScala.map(f => f.getName).zipWithIndex.toMap
      if (allNames.nonEmpty) {
        val arrowVectors =
          allNames.map(name => {
            fieldNameIdxMap.get(name) match {
              case Some(fieldIdx) =>
                new OdpsArrowColumnVector(vectors.get(fieldIdx),
                  schema.getColumn(name).get().getTypeInfo)
              case None =>
                throw new RuntimeException("Missing column " + name + " from arrow reader.")
            }
          }).toList
        columnarBatch = new ColumnarBatch(arrowVectors.toArray)
      } else {
        // No columns requested: emit an empty batch that only carries the row count.
        columnarBatch = new ColumnarBatch(new Array[OdpsArrowColumnVector](0).toArray)
      }
      columnarBatch.setNumRows(root.getRowCount)
    }

    override def next(): Boolean = {
      if (asyncRead) {
        val nextObject = asyncQueueForVisit.take()
        if (nextObject == DONE_SENTINEL) {
          false
        } else nextObject match {
          case t: Throwable =>
            throw new IOException(t)
          case _ =>
            updateColumnBatch(nextObject.asInstanceOf[VectorSchemaRoot])
            true
        }
      } else {
        try {
          if (!arrowReader.hasNext) {
            false
          } else {
            updateColumnBatch(arrowReader.get())
            loadData = true
            true
          }
        } catch {
          case cause: Throwable =>
            val splitIndex = odpsScanPartition.inputSplit match {
              case split: InputSplitWithIndex =>
                split.getSplitIndex
              case split: RowRangeInputSplit =>
                split.getRowRange.getStartIndex
              case _ => 0
            }
            val sessionId = odpsScanPartition.inputSplit.getSessionId
            logError(s"Partition reader $splitIndex for session $sessionId " +
              s"encountered failure ${cause.getMessage}")
            if (!loadData) {
              if (arrowReader != null) {
                arrowReader.close()
              }
              arrowReader = odpsScanPartition.scan
                .createArrowReader(odpsScanPartition.inputSplit, readerOptions)
              if (!arrowReader.hasNext) {
                false
              } else {
                updateColumnBatch(arrowReader.get())
                loadData = true
                true
              }
            } else {
              throw cause
            }
        }
      }
    }

    override def get(): ColumnarBatch = columnarBatch

    override def close(): Unit = {
      if (columnarBatch != null) {
        columnarBatch.close()
      }
      if (!asyncRead) {
        // In sync mode the reader is still open here: record bytes read, then close it.
        // In async mode the background thread already did both.
        arrowReader.currentMetricsValues.counter(MetricNames.BYTES_COUNT).ifPresent(c =>
          inputBytes = c.getCount)
        arrowReader.close()
      }
      TaskContext.get().taskMetrics().inputMetrics
        .incBytesRead(inputBytes)
      if (executor != null) {
        executor.shutdown()
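        // Drain and close any batches still queued so arrow buffers are not leaked.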
        while (!asyncQueueForVisit.isEmpty) {
          val data = asyncQueueForVisit.take()
          data match {
            case root: VectorSchemaRoot =>
              root.close()
            case _ =>
          }
        }
      }
    }
  }
}