in src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcWriterUtils.scala [41:99]
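/**
 * Saves the rows of `df` to the JDBC table `table` at `url`. Each partition
 * writes through its own SingleTableJdbcWriter, throttled to `writesPerSecond`
 * writes per second per partition, with optional exponential-backoff retries.
 * If `postWriteSql` is non-empty it is executed once the write completes.
 */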
def saveTable(
    df: DataFrame,
    url: String,
    table: String,
    textColumns: Seq[String],
    postWriteSql: String,
    retry: Boolean,
    writesPerSecond: Double,
    properties: Properties): Unit = {
  val dialect = JdbcDialects.get(url)
  val nullTypes: Array[Int] = df.schema.fields.map { field =>
    getJdbcType(field.dataType, dialect).jdbcNullType
  }
  val rddSchema = df.schema
  val columnNames = rddSchema.fields.map(_.name).toSeq.asJava
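  // Write each partition with its own SingleTableJdbcWriter and RateLimiter,
  // returning the partition's row count so the driver can sum and log the total.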
  val mapPartitionsResult = df.rdd.mapPartitions { iterator =>
    logInfo(s"JdbcWriterUtils.saveTable: $table")
    val rateLimiter = RateLimiter.create(writesPerSecond)
    try {
      val singleTableJdbcWriter = new SingleTableJdbcWriter(url, table, null, textColumns.asJava)
      val atomicLong = new AtomicLong()
      iterator.foreach { row =>
        atomicLong.incrementAndGet()
        val columnValues = row.toSeq.map(_.asInstanceOf[AnyRef]).asJava
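        // With retries enabled, each write goes through ExponentialBackoffRetryPolicy(3, 100L),
        // acquiring a rate-limiter permit before every attempt.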
        if (retry) {
          val retryPolicy = new ExponentialBackoffRetryPolicy[String](3, 100L)
          retryPolicy.attempt(new Callable[String] {
            override def call(): String = {
              rateLimiter.acquire()
              singleTableJdbcWriter.writeColumns(columnNames, columnValues)
              null
            }
          })
        } else {
          // Throttle the non-retry path to the same configured rate.
          rateLimiter.acquire()
          singleTableJdbcWriter.writeColumns(columnNames, columnValues)
        }
      }
      Seq(atomicLong.longValue()).iterator
    } catch {
      case e: Throwable =>
        logError(s"Failed to write partition to jdbc table $table", e)
        throw e
    }
  }
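  // collect() runs the distributed write; each element is one partition's row count.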
  val rowCounts = mapPartitionsResult.collect()
  val totalSavedRows = rowCounts.sum
  logInfo(s"Saved $totalSavedRows rows to jdbc table $table")
  if (postWriteSql != null && postWriteSql.nonEmpty) {
    logInfo(s"Running post save sql: $postWriteSql")
    SqlUtils.executeJdbcUpdate(url, postWriteSql)
  }
}