in spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriter.scala [54:106]
def write(dataFrame: DataFrame): Unit = {
  val sc = dataFrame.sqlContext.sparkContext
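  // Accumulator that collects the transaction ids of pre-committed stream loads,
  // so they can be committed or aborted later when two-phase commit (2PC) is enabled.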
  val preCommittedTxnAcc = sc.collectionAccumulator[Int]("preCommittedTxnAcc")
  if (enable2PC) {
    sc.addSparkListener(new DorisTransactionListener(preCommittedTxnAcc, dorisStreamLoader))
  }
  var resultRdd = dataFrame.rdd
  val dfColumns = dataFrame.columns
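  // Optionally resize the write parallelism: repartition triggers a shuffle,
  // coalesce only merges existing partitions.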
  if (Objects.nonNull(sinkTaskPartitionSize)) {
    resultRdd = if (sinkTaskUseRepartition) resultRdd.repartition(sinkTaskPartitionSize) else resultRdd.coalesce(sinkTaskPartitionSize)
  }
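  // Convert each Row to a java.util.List and stream-load it to Doris in batches of batchSize rows.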
  resultRdd
    .map(_.toSeq.map(_.asInstanceOf[AnyRef]).toList.asJava)
    .foreachPartition(partition => {
      partition
        .grouped(batchSize)
        .foreach(batch => flush(batch, dfColumns))
    })
  /**
   * Flush a batch of rows to Doris, retrying when the flush fails.
   */
  def flush(batch: Iterable[util.List[Object]], dfColumns: Array[String]): Unit = {
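    // Retry the stream load up to maxRetryTimes, waiting batchInterValMs between attempts.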
    Utils.retry[util.List[Integer], Exception](maxRetryTimes, Duration.ofMillis(batchInterValMs.toLong), logger) {
      dorisStreamLoader.loadV2(batch.toList.asJava, dfColumns, enable2PC)
    } match {
      case Success(txnIds) => if (enable2PC) txnIds.asScala.foreach(txnId => preCommittedTxnAcc.add(txnId))
      case Failure(e) =>
        if (enable2PC) {
          // If the task fails, the accumulator value is never returned to the driver,
          // so abort all transactions pre-committed inside this task before rethrowing.
          logger.info("load task failed, start aborting previously pre-committed transactions")
          val abortFailedTxnIds = mutable.Buffer[Int]()
          preCommittedTxnAcc.value.asScala.foreach(txnId => {
            Utils.retry[Unit, Exception](3, Duration.ofSeconds(1), logger) {
              dorisStreamLoader.abort(txnId)
            } match {
              case Success(_) =>
              case Failure(_) => abortFailedTxnIds += txnId
            }
          })
          if (abortFailedTxnIds.nonEmpty) logger.warn("not aborted txn ids: {}", abortFailedTxnIds.mkString(","))
          preCommittedTxnAcc.reset()
        }
        throw new IOException(
          s"Failed to load batch data on BE: ${dorisStreamLoader.getLoadUrlStr} node and exceeded the max ${maxRetryTimes} retry times.", e)
    }
  }
}
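
For reference, a minimal sketch of what a retry helper with the shape used above could look like. The real implementation lives in the connector's own Utils object; the signature here is only inferred from the two call sites (attempt count, wait interval, logger, and a by-name body whose result comes back as a scala.util.Try), so treat names and details as assumptions.

import java.time.Duration
import org.slf4j.Logger
import scala.annotation.tailrec
import scala.reflect.ClassTag
import scala.util.{Failure, Success, Try}

object RetrySketch {
  // Hypothetical stand-in for Utils.retry, inferred from the call sites above:
  // evaluate `f`; on an exception of type T, sleep `interval` and try again,
  // at most `retryTimes` more times; report the final outcome as a Try.
  @tailrec
  def retry[R, T <: Throwable: ClassTag](retryTimes: Int, interval: Duration, logger: Logger)(f: => R): Try[R] =
    Try(f) match {
      case Success(result) => Success(result)
      case Failure(e: T) if retryTimes > 0 =>
        logger.warn(s"Attempt failed, $retryTimes retries left", e)
        Thread.sleep(interval.toMillis)
        retry(retryTimes - 1, interval, logger)(f)
      case Failure(e) => Failure(e)
    }
}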