in sdk/cosmos/azure-cosmos-spark_3_2-12/src/main/scala/com/azure/cosmos/spark/CosmosClientCache.scala [209:418]
/**
 * Builds a new [[CosmosAsyncClient]] from the given client configuration.
 *
 * The builder is configured in this order: endpoint / user agent / throttling retry options,
 * credentials (by auth config type), optional Micrometer metrics, endpoint discovery,
 * read consistency, connection mode (gateway or direct), preferred regions, client telemetry,
 * optional metadata cache snapshot, and finally the configured builder interceptors.
 * After the client has been built, the configured client interceptors are applied in order.
 *
 * @param cosmosClientConfiguration the configuration that drives every builder setting
 * @param cosmosClientStateHandle   optional metadata cache snapshot; it is applied to the builder
 *                                  unless the current Spark task is a retry attempt
 * @return the client returned by the last client interceptor, or the freshly built client when
 *         no client interceptors are configured
 * @throws IllegalStateException    when native transport is enforced but Epoll is not available
 * @throws IllegalArgumentException when the auth config is of an unsupported type
 */
private[this] def createCosmosAsyncClient(cosmosClientConfiguration: CosmosClientConfiguration,
                                          cosmosClientStateHandle: Option[CosmosClientMetadataCachesSnapshot]): CosmosAsyncClient = {
  // Fail fast when the configuration mandates native transport and Epoll cannot be used.
  // Epoll is only probed when the enforcement flag is set.
  if (cosmosClientConfiguration.enforceNativeTransport) {
    if (!io.netty.channel.epoll.Epoll.isAvailable) {
      throw new IllegalStateException(
        "The enforcement of native transport is enabled in your configuration and native transport is not " +
          "available. Either ensure `spark.cosmos.enforceNativeTransport` is set to false or make " +
          "sure you use a Spark environment supporting native transport.",
        io.netty.channel.epoll.Epoll.unavailabilityCause()
      )
    }
  }

  // Throttled requests are retried with the largest retry count / wait time used by this connector.
  val throttlingOptions = new ThrottlingRetryOptions()
    .setMaxRetryAttemptsOnThrottledRequests(Int.MaxValue)
    .setMaxRetryWaitTime(Duration.ofSeconds((Integer.MAX_VALUE / 1000) - 1))

  var builder = new CosmosClientBuilder()
    .endpoint(cosmosClientConfiguration.endpoint)
    .userAgentSuffix(cosmosClientConfiguration.applicationName)
    .throttlingRetryOptions(throttlingOptions)

  // --- Credentials -------------------------------------------------------------------------
  val authConfig = cosmosClientConfiguration.authConfig
  authConfig match {
    case masterKey: CosmosMasterKeyAuthConfig =>
      builder.key(masterKey.accountKey)

    case servicePrincipal: CosmosServicePrincipalAuthConfig =>
      // Evaluated lazily at the call site so the construction order of the credential builders
      // stays exactly as before.
      def activeDirectoryEndpoint =
        new AzureEnvironment(cosmosClientConfiguration.azureEnvironmentEndpoints).getActiveDirectoryEndpoint()

      // A certificate (base64 encoded PEM) takes precedence over a client secret.
      val credential = servicePrincipal.clientCertPemBase64 match {
        case Some(pemBase64) =>
          val pemStream = new ByteArrayInputStream(Base64.getDecoder.decode(pemBase64))
          new ClientCertificateCredentialBuilder()
            .authorityHost(activeDirectoryEndpoint)
            .tenantId(servicePrincipal.tenantId)
            .clientId(servicePrincipal.clientId)
            .pemCertificate(pemStream)
            .sendCertificateChain(servicePrincipal.sendChain)
            .build()
        case None =>
          new ClientSecretCredentialBuilder()
            .authorityHost(activeDirectoryEndpoint)
            .tenantId(servicePrincipal.tenantId)
            .clientId(servicePrincipal.clientId)
            .clientSecret(servicePrincipal.clientSecret.get)
            .build()
      }
      builder.credential(credential)

    case managedIdentity: CosmosManagedIdentityAuthConfig =>
      builder.credential(createTokenCredential(managedIdentity))

    case accessToken: CosmosAccessTokenAuthConfig =>
      builder.credential(createTokenCredential(accessToken))

    case _ =>
      throw new IllegalArgumentException(s"Authorization type ${authConfig.getClass} is not supported")
  }

  // --- Metrics (only when a meter registry has been registered) ----------------------------
  CosmosClientMetrics.meterRegistry.foreach { registry =>
    val nameSuffix = cosmosClientConfiguration.customApplicationNameSuffix.getOrElse("")

    // With an active Spark session the correlation id is "[<suffix>-]<executorId>-<appName>";
    // without one it is just the (possibly empty) custom suffix.
    val correlationId = SparkSession.getActiveSession match {
      case None => nameSuffix
      case Some(session) =>
        val sparkContext = session.sparkContext
        val prefix = if (Strings.isNullOrWhiteSpace(nameSuffix)) "" else s"$nameSuffix-"
        s"$prefix${CosmosClientMetrics.executorId}-${sparkContext.appName}"
    }

    val micrometerOptions = new CosmosMicrometerMetricsOptions()
      .meterRegistry(registry)
      .configureDefaultTagNames(
        CosmosMetricTagName.CONTAINER,
        CosmosMetricTagName.CLIENT_CORRELATION_ID,
        CosmosMetricTagName.OPERATION,
        CosmosMetricTagName.OPERATION_STATUS_CODE,
        CosmosMetricTagName.PARTITION_KEY_RANGE_ID,
        CosmosMetricTagName.SERVICE_ADDRESS,
        CosmosMetricTagName.ADDRESS_RESOLUTION_COLLECTION_MAP_REFRESH,
        CosmosMetricTagName.ADDRESS_RESOLUTION_FORCED_REFRESH,
        CosmosMetricTagName.REQUEST_STATUS_CODE,
        CosmosMetricTagName.REQUEST_OPERATION_TYPE
      )
      .setMetricCategories(
        CosmosMetricCategory.SYSTEM,
        CosmosMetricCategory.OPERATION_SUMMARY,
        CosmosMetricCategory.REQUEST_SUMMARY,
        CosmosMetricCategory.DIRECT_ADDRESS_RESOLUTIONS,
        CosmosMetricCategory.DIRECT_REQUESTS,
        CosmosMetricCategory.DIRECT_CHANNELS
      )

    builder.clientTelemetryConfig(
      new CosmosClientTelemetryConfig()
        .metricsOptions(micrometerOptions)
        .clientCorrelationId(correlationId))
  }

  // NOTE(review): the flag name refers to TCP connection endpoint rediscovery, while the switch
  // toggled here is the builder's endpoint discovery - confirm this mapping is intended.
  if (cosmosClientConfiguration.disableTcpConnectionEndpointRediscovery) {
    builder.endpointDiscoveryEnabled(false)
  }

  // --- Read consistency: EVENTUAL maps to the consistency level, DEFAULT leaves the builder
  // untouched, every other strategy is passed through as-is.
  val consistencyStrategy = cosmosClientConfiguration.readConsistencyStrategy
  if (consistencyStrategy == ReadConsistencyStrategy.EVENTUAL) {
    builder = builder.consistencyLevel(ConsistencyLevel.EVENTUAL)
  } else if (consistencyStrategy != ReadConsistencyStrategy.DEFAULT) {
    builder = builder.readConsistencyStrategy(consistencyStrategy)
  }

  // --- Connection mode ---------------------------------------------------------------------
  if (cosmosClientConfiguration.useGatewayMode) {
    builder = builder.gatewayMode(
      new GatewayConnectionConfig()
        .setMaxConnectionPoolSize(cosmosClientConfiguration.httpConnectionPoolSize))
  } else {
    val directRequestTimeout = Duration.ofSeconds(CosmosConstants.defaultDirectRequestTimeoutInSeconds)
    val baseDirectConfig = new DirectConnectionConfig()
      .setConnectTimeout(directRequestTimeout)
      .setNetworkRequestTimeout(directRequestTimeout)

    // Apply the I/O-threads-per-core override: Spark often works with large payloads and there
    // have been indicators that the default number of I/O threads can be too low for them.
    val directConfigWithIoThreads = SparkBridgeImplementationInternal
      .setIoThreadCountPerCoreFactor(baseDirectConfig, SparkBridgeImplementationInternal.getIoThreadCountPerCoreOverride)

    // Spark workloads often result in very high CPU load; raising the priority of the I/O
    // threads has been seen to reduce transient I/O errors/timeouts in that situation.
    val directConfigWithPriority = SparkBridgeImplementationInternal
      .setIoThreadPriority(directConfigWithIoThreads, Thread.MAX_PRIORITY)

    builder = builder.directMode(directConfigWithPriority)

    // Proactive connection warm-up is only configured for a non-empty setting.
    cosmosClientConfiguration.proactiveConnectionInitialization
      .filter(_.nonEmpty)
      .foreach { proactiveInitSetting =>
        val containerIdentities = CosmosAccountConfig.parseProactiveConnectionInitConfigs(proactiveInitSetting)
        val warmupConfig = new CosmosContainerProactiveInitConfigBuilder(containerIdentities)
          .setAggressiveWarmupDuration(
            Duration.ofSeconds(cosmosClientConfiguration.proactiveConnectionInitializationDurationInSeconds))
          .setProactiveConnectionRegionsCount(1)
          .build
        builder.openConnectionsAndInitCaches(warmupConfig)
      }
  }

  cosmosClientConfiguration.preferredRegionsList.foreach { regions =>
    builder.preferredRegions(regions.toList.asJava)
  }

  // --- Client telemetry: enabling it also publishes the endpoint and the enabled flag as
  // JVM system properties.
  if (cosmosClientConfiguration.enableClientTelemetry) {
    val telemetryEndpoint = cosmosClientConfiguration.clientTelemetryEndpoint.getOrElse(
      "https://tools.cosmos.azure.com/api/clienttelemetry/trace")
    System.setProperty("COSMOS.CLIENT_TELEMETRY_ENDPOINT", telemetryEndpoint)
    System.setProperty("COSMOS.CLIENT_TELEMETRY_ENABLED", "true")
    builder.clientTelemetryEnabled(true)
  }

  // There were incidents where the connector could not recover even after Spark restarted
  // executors - most likely due to stale cache state being broadcast.
  // Ideally the SDK would always be able to recover from stale cache state, but the main
  // purpose of broadcasting the cache state is to avoid peaks in metadata RU usage when
  // multiple workers/executors are all started at the same time.
  // Skipping the broadcast cache state for retries should be safe - not all executors will be
  // restarted at the same time - and it adds an additional layer of safety.
  val isTaskRetryAttempt: Boolean = Option(TaskContext.get()).exists(_.attemptNumber() > 0)
  cosmosClientStateHandle.foreach { metadataCacheSnapshot =>
    if (isTaskRetryAttempt) {
      logInfo(s"Ignoring broadcast client state handle because Task is getting retried. " +
        s"Attempt Count: ${TaskContext.get().attemptNumber()}")
    } else {
      SparkBridgeImplementationInternal.setMetadataCacheSnapshot(builder, metadataCacheSnapshot)
    }
  }

  // Builder interceptors run last, in order, each one receiving the previous one's result.
  cosmosClientConfiguration.clientBuilderInterceptors.foreach { builderInterceptors =>
    logInfo(s"Applying CosmosClientBuilder interceptors")
    builderInterceptors.foreach { interceptor =>
      builder = interceptor.apply(builder)
    }
  }

  var client = builder.buildAsyncClient()

  // Client interceptors may wrap or replace the freshly built client, again applied in order.
  cosmosClientConfiguration.clientInterceptors.foreach { clientInterceptors =>
    logInfo(s"Applying CosmosClient interceptors")
    clientInterceptors.foreach { interceptor =>
      client = interceptor.apply(client)
    }
  }

  client
}