in connector/src/main/scala/com/datastax/spark/connector/rdd/CassandraCoGroupedRDD.scala [195:253]
override def compute(split: Partition, context: TaskContext): Iterator[Seq[Seq[T]]] = {
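  // Compute the co-grouped result for one Spark partition: stream each table's
  // rows for the partition's token ranges and merge-join them on token.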
  /** Open a session through the RDD's own connector when its cluster configuration differs from ours. */
  def openSession(rdd: CassandraTableScanRDD[T]): CqlSession = {
    if (connector == rdd.connector) {
      connector.openSession()
    } else {
      rdd.connector.openSession()
    }
  }
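  // Close every session opened for this task; already-closed sessions are skipped.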
  def closeSessions(sessions: Seq[CqlSession]): Unit = {
    for (s <- sessions) {
      if (!s.isClosed) s.close()
    }
  }
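  // Pair each scan RDD with the session it will read through; the sessions are
  // closed in the task-completion listener registered below.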
  val rddWithSessions: Seq[(CassandraTableScanRDD[T], CqlSession)] =
    scanRDDs.map(rdd => (rdd, openSession(rdd)))
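  // Existential aliases (Scala 2 `forSome` syntax): the partition's token and
  // value types are not statically known at this point.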
  type V = t forSome { type t }
  type K = t forSome { type t <: com.datastax.spark.connector.rdd.partitioner.dht.Token[V] }
  val partition = split.asInstanceOf[CassandraPartition[V, K]]
  val tokenRanges = partition.tokenRanges
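  // Enable task metrics if any of the co-grouped RDDs asked for them.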
  val metricsReadConf = new ReadConf(taskMetricsEnabled = scanRDDs.exists(_.readConf.taskMetricsEnabled))
  val metricsUpdater = InputMetricsUpdater(context, metricsReadConf)
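  // For every token range owned by this partition, fetch the matching rows from
  // each table and merge-join them on token, yielding one Seq[Seq[T]] per token.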
  val mergingIterator: Iterator[Seq[Seq[T]]] = tokenRanges.iterator.flatMap { tokenRange =>
    val rowsWithMeta =
      rddWithSessions.map { case (rdd, session) => fetchTokenRange(session, rdd, tokenRange, metricsUpdater) }
    val metaData = rowsWithMeta.map(_._1)
    val rows = rowsWithMeta.map(_._2)
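    // Merge the per-table row streams in token order; each emitted element holds
    // one group of same-token rows per scan RDD.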
    val rowMerger = new MultiMergeJoinIterator[Row, Token](
      rows,
      tokenExtractor
    )
    rowMerger.map { (allGroups: Seq[Seq[Row]]) =>
      allGroups.zip(metaData).zip(scanRDDs).map { case ((rows, meta), rdd) =>
        convertRowSeq(rows, rdd.rowReader, meta)
      }
    }
  }
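  // Count rows as they stream through so the completion listener can log the total.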
  val countingIterator = new CountingIterator(mergingIterator)
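  // On task completion: flush the metrics, log a summary, and close the sessions
  // opened above.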
  context.addTaskCompletionListener { taskContext =>
    val duration = metricsUpdater.finish() / 1000000000d
    logDebug(
      f"""Fetched ${countingIterator.count} rows from
         |${scanRDDs.head.keyspaceName}
         |for partition ${partition.index} in $duration%.3f s.""".stripMargin)
    closeSessions(rddWithSessions.map(_._2))
    taskContext
  }
  countingIterator
}