override def createReader(partition: InputPartition)

in spark-connector/common/src/main/scala/org/apache/spark/sql/odps/OdpsPartitionReaderFactory.scala [71:194]


  override def createReader(partition: InputPartition): PartitionReader[InternalRow] = {
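    // Fast path: when no output columns are requested (e.g. a bare row count), emit an
    // empty row `rowCount` times instead of reading any data from the split.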
    if (output.isEmpty) {
      assert(partition.isInstanceOf[OdpsEmptyColumnPartition], "Output column is empty")
      val emptyColumnPartition = partition.asInstanceOf[OdpsEmptyColumnPartition]
      return new PartitionReader[InternalRow] {
        private val unsafeRow: InternalRow = InternalRow()
        private var count = 0L
        override def next(): Boolean = {
          if (count < emptyColumnPartition.rowCount) {
            count = count + 1
            true
          } else {
            false
          }
        }
        override def get(): InternalRow = unsafeRow
        override def close(): Unit = {
        }
      }
    }

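    // Rebuild the ODPS environment settings on the executor from the broadcast configuration.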
    val conf = broadcastedConf.value.value
    val settings = OdpsClient.builder.config(conf).getOrCreate.getEnvironmentSettings
    val odpsScanPartition = partition.asInstanceOf[OdpsScanPartition]
    val supportRecordReader = odpsScanPartition.scan.supportsDataFormat(recordDataFormat)
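    // Prefer the row-based record reader when the table scan supports the record data format.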
    if (supportRecordReader) {
      val readerOptions = ReaderOptions.newBuilder()
        .withSettings(settings)
        .build()

      val recordReader = odpsScanPartition.scan
        .createRecordReader(odpsScanPartition.inputSplit, readerOptions)
      val readTypeInfos = odpsScanPartition.scan.readSchema.getColumns.asScala.map(_.getTypeInfo)

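      // Row-based path: convert each ODPS column value to Spark's internal representation,
      // fill a mutable row, and project it into an UnsafeRow for each record.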
      new PartitionReader[InternalRow] {
        private val converters = readTypeInfos.map(OdpsUtils.odpsData2SparkData)
        private val currentRow = new SpecificInternalRow(allTypes)
        private val unsafeProjection = GenerateUnsafeProjection.generate(output, output)
        private var unsafeRow: UnsafeRow = _

        override def next(): Boolean = {
          if (!recordReader.hasNext) {
            false
          } else {
            val record = recordReader.get()
            var i = 0
            if (record ne null) {
              while (i < converters.length) {
                val value = record.get(i)
                if (value ne null) {
                  currentRow.update(i, converters(i)(value))
                } else {
                  currentRow.setNullAt(i)
                }
                i += 1
              }
            } else {
              while (i < allTypes.length) {
                currentRow.setNullAt(i)
                i += 1
              }
            }
            unsafeRow = unsafeProjection(currentRow)
            true
          }
        }

        override def get(): InternalRow = unsafeRow

        override def close(): Unit = {
          recordReader.currentMetricsValues.counter(MetricNames.BYTES_COUNT).ifPresent(c =>
            TaskContext.get().taskMetrics().inputMetrics
              .incBytesRead(c.getCount))
          recordReader.close()
        }
      }
    } else {
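      // Fall back to the Arrow-based columnar reader and flatten each batch into rows.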
      val supportArrowReader = odpsScanPartition.scan.supportsDataFormat(arrowDataFormat)
      if (supportArrowReader) {
        new PartitionReader[InternalRow] {
          private var unsafeRow: InternalRow = _
          private val batchReader = createColumnarReader(partition)
          private var rowIterator: Iterator[InternalRow] = _

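          // Advance to the next non-empty batch, refreshing the row iterator once the
          // current one is exhausted; returns false when no more batches remain.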
          private def hasNext: Boolean = {
            if (rowIterator == null || !rowIterator.hasNext) {
              if (batchReader.next) {
                val batch = batchReader.get
                if (batch != null) {
                  rowIterator = batch.rowIterator.asScala
                } else {
                  rowIterator = null
                }
              } else {
                rowIterator = null
              }
            }
            rowIterator != null && rowIterator.hasNext
          }

          override def next(): Boolean = {
            if (!hasNext) {
              false
            } else {
              unsafeRow = rowIterator.next()
              true
            }
          }

          override def get(): InternalRow = unsafeRow

          override def close(): Unit = {
            batchReader.close()
          }
        }
      } else {
        throw new UnsupportedOperationException(
          "Table provider does not support record or arrow data format")
      }
    }
  }
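
A minimal sketch of how Spark's DataSource V2 runtime typically drives a reader produced by this factory; the factory and partition instances below are illustrative and not part of this file:

  val reader = readerFactory.createReader(inputPartition) // hypothetical instances
  try {
    while (reader.next()) {
      val row: InternalRow = reader.get() // only valid until the next call to next()
      // consume the row
    }
  } finally {
    reader.close() // on the record reader path this also records bytes-read metrics
  }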