in spark-connector/common/src/main/scala/org/apache/spark/sql/odps/OdpsPartitionReaderFactory.scala [71:194]
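/**
 * Creates a row-based [[PartitionReader]] for the given input partition. Three cases
 * are handled: an empty projection, where only the partition's row count is replayed;
 * a scan that supports the ODPS record format; and a scan that supports the Arrow
 * format, which is read in columnar batches and converted row by row.
 */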
override def createReader(partition: InputPartition): PartitionReader[InternalRow] = {
  if (output.isEmpty) {
    // No columns are projected: only the row count matters, so emit empty rows.
    assert(partition.isInstanceOf[OdpsEmptyColumnPartition],
      "Expected OdpsEmptyColumnPartition when no output columns are requested")
    val emptyColumnPartition = partition.asInstanceOf[OdpsEmptyColumnPartition]
    return new PartitionReader[InternalRow] {
      private val emptyRow: InternalRow = InternalRow()
      private var count = 0L

      override def next(): Boolean = {
        if (count < emptyColumnPartition.rowCount) {
          count += 1
          true
        } else {
          false
        }
      }

      override def get(): InternalRow = emptyRow

      override def close(): Unit = {}
    }
  }
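
  // Rebuild the ODPS environment settings from the broadcast configuration and choose a
  // reader based on the data formats supported by the scan.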
  val conf = broadcastedConf.value.value
  val settings = OdpsClient.builder.config(conf).getOrCreate.getEnvironmentSettings
  val odpsScanPartition = partition.asInstanceOf[OdpsScanPartition]
  val supportRecordReader = odpsScanPartition.scan.supportsDataFormat(recordDataFormat)
  if (supportRecordReader) {
    val readerOptions = ReaderOptions.newBuilder()
      .withSettings(settings)
      .build()
    val recordReader = odpsScanPartition.scan
      .createRecordReader(odpsScanPartition.inputSplit, readerOptions)
    val readTypeInfos = odpsScanPartition.scan.readSchema.getColumns.asScala.map(_.getTypeInfo)
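    // Record-based path: convert each ODPS record to Spark's internal representation,
    // one converter per projected column, and project the result to an UnsafeRow.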
    new PartitionReader[InternalRow] {
      private val converters = readTypeInfos.map(OdpsUtils.odpsData2SparkData)
      private val currentRow = new SpecificInternalRow(allTypes)
      private val unsafeProjection = GenerateUnsafeProjection.generate(output, output)
      private var unsafeRow: UnsafeRow = _

      override def next(): Boolean = {
        if (!recordReader.hasNext) {
          false
        } else {
          val record = recordReader.get()
          var i = 0
          if (record ne null) {
            // Null column values in the record become nulls in the Spark row.
            while (i < converters.length) {
              val value = record.get(i)
              if (value ne null) {
                currentRow.update(i, converters(i)(value))
              } else {
                currentRow.setNullAt(i)
              }
              i += 1
            }
          } else {
            // A null record is treated as a row with every column set to null.
            while (i < allTypes.length) {
              currentRow.setNullAt(i)
              i += 1
            }
          }
          unsafeRow = unsafeProjection(currentRow)
          true
        }
      }

      override def get(): InternalRow = unsafeRow

      override def close(): Unit = {
        // Report the bytes read to Spark's input metrics before closing the reader.
        recordReader.currentMetricsValues.counter(MetricNames.BYTES_COUNT).ifPresent(c =>
          TaskContext.get().taskMetrics().inputMetrics
            .incBytesRead(c.getCount))
        recordReader.close()
      }
    }
  } else {
    val supportArrowReader = odpsScanPartition.scan.supportsDataFormat(arrowDataFormat)
    if (supportArrowReader) {
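      // Arrow path: pull columnar batches from the columnar reader and expose each
      // batch to Spark as a row iterator.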
      new PartitionReader[InternalRow] {
        private var unsafeRow: InternalRow = _
        private val batchReader = createColumnarReader(partition)
        private var rowIterator: Iterator[InternalRow] = _

        // Fetches the next batch when the current row iterator is exhausted.
        private def hasNext: Boolean = {
          if (rowIterator == null || !rowIterator.hasNext) {
            if (batchReader.next) {
              val batch = batchReader.get
              if (batch != null) {
                rowIterator = batch.rowIterator.asScala
              } else {
                rowIterator = null
              }
            } else {
              rowIterator = null
            }
          }
          rowIterator != null && rowIterator.hasNext
        }

        override def next(): Boolean = {
          if (!hasNext) {
            false
          } else {
            unsafeRow = rowIterator.next()
            true
          }
        }

        override def get(): InternalRow = unsafeRow

        override def close(): Unit = {
          batchReader.close()
        }
      }
    } else {
      throw new UnsupportedOperationException(
        "Table provider does not support record or arrow data format")
    }
  }
}