private[this] def createCosmosAsyncClient()

in sdk/cosmos/azure-cosmos-spark_3_2-12/src/main/scala/com/azure/cosmos/spark/CosmosClientCache.scala [209:418]


  /**
   * Builds a new [[CosmosAsyncClient]] from the given client configuration.
   *
   * Steps, in order: verify native transport when its enforcement is requested; configure
   * endpoint, user agent suffix and throttling retries; apply the credential matching the
   * auth config; attach Micrometer metrics when a meter registry is present; apply the
   * consistency, connection-mode, preferred-region and client-telemetry settings; seed the
   * metadata caches from the snapshot (unless the current task is a retry); and finally run
   * the configured builder interceptors and client interceptors.
   *
   * @param cosmosClientConfiguration settings used to configure the client builder
   * @param cosmosClientStateHandle   optional metadata caches snapshot; ignored when the
   *                                  current Spark task has an attempt number greater than 0
   * @return the client returned by the builder after all client interceptors were applied
   * @throws IllegalStateException    if native transport is enforced but Epoll is unavailable
   * @throws IllegalArgumentException if the auth config type is not supported, or a service
   *                                  principal config carries neither certificate nor secret
   */
  private[this] def createCosmosAsyncClient(cosmosClientConfiguration: CosmosClientConfiguration,
                                            cosmosClientStateHandle: Option[CosmosClientMetadataCachesSnapshot]): CosmosAsyncClient = {
    if (cosmosClientConfiguration.enforceNativeTransport && !io.netty.channel.epoll.Epoll.isAvailable) {
      throw new IllegalStateException(
        "The enforcement of native transport is enabled in your configuration and native transport is not " +
          "available. Either ensure `spark.cosmos.enforceNativeTransport` is set to false or make " +
          "sure you use a Spark environment supporting native transport.",
        io.netty.channel.epoll.Epoll.unavailabilityCause()
      )
    }

    // Kept as a `var`: several steps below reassign the builder, and builder interceptors
    // are free to hand back a different instance.
    var builder = new CosmosClientBuilder()
      .endpoint(cosmosClientConfiguration.endpoint)
      .userAgentSuffix(cosmosClientConfiguration.applicationName)
      .throttlingRetryOptions(
        new ThrottlingRetryOptions()
          .setMaxRetryAttemptsOnThrottledRequests(Int.MaxValue)
          .setMaxRetryWaitTime(Duration.ofSeconds((Integer.MAX_VALUE / 1000) - 1)))

    // --- Credentials ---------------------------------------------------------------------
    val authConfig = cosmosClientConfiguration.authConfig
    authConfig match {
      case masterKeyAuthConfig: CosmosMasterKeyAuthConfig =>
        builder.key(masterKeyAuthConfig.accountKey)

      case servicePrincipalAuthConfig: CosmosServicePrincipalAuthConfig =>
        val authorityHost = new AzureEnvironment(cosmosClientConfiguration.azureEnvironmentEndpoints)
          .getActiveDirectoryEndpoint()

        // A configured client certificate takes precedence over a client secret.
        val tokenCredential = servicePrincipalAuthConfig.clientCertPemBase64 match {
          case Some(certPemBase64) =>
            val certInputStream = new ByteArrayInputStream(Base64.getDecoder.decode(certPemBase64))
            new ClientCertificateCredentialBuilder()
              .authorityHost(authorityHost)
              .tenantId(servicePrincipalAuthConfig.tenantId)
              .clientId(servicePrincipalAuthConfig.clientId)
              .pemCertificate(certInputStream)
              .sendCertificateChain(servicePrincipalAuthConfig.sendChain)
              .build()
          case None =>
            // Fail with an actionable message instead of an opaque `None.get`.
            // NOTE(review): config parsing presumably already guarantees one of the two is
            // set - confirm; this guard only improves the diagnostic if it does not.
            val clientSecret = servicePrincipalAuthConfig.clientSecret.getOrElse(
              throw new IllegalArgumentException(
                "Service principal authentication requires either a client certificate or a client secret, " +
                  "but neither is configured."))
            new ClientSecretCredentialBuilder()
              .authorityHost(authorityHost)
              .tenantId(servicePrincipalAuthConfig.tenantId)
              .clientId(servicePrincipalAuthConfig.clientId)
              .clientSecret(clientSecret)
              .build()
        }

        builder.credential(tokenCredential)

      case managedIdentityAuthConfig: CosmosManagedIdentityAuthConfig =>
        builder.credential(createTokenCredential(managedIdentityAuthConfig))

      case accessTokenAuthConfig: CosmosAccessTokenAuthConfig =>
        builder.credential(createTokenCredential(accessTokenAuthConfig))

      case _ =>
        throw new IllegalArgumentException(s"Authorization type ${authConfig.getClass} is not supported")
    }

    // --- Metrics -------------------------------------------------------------------------
    // The registry is read exactly once here rather than via isDefined followed by get.
    CosmosClientMetrics.meterRegistry.foreach { registry =>
      val customApplicationNameSuffix = cosmosClientConfiguration.customApplicationNameSuffix
        .getOrElse("")

      // With an active Spark session the correlation id is
      // "[<suffix>-]<executorId>-<appName>"; without one it is just the (possibly empty) suffix.
      val clientCorrelationId = SparkSession.getActiveSession match {
        case Some(session) =>
          val appName = session.sparkContext.appName
          if (Strings.isNullOrWhiteSpace(customApplicationNameSuffix)) {
            s"${CosmosClientMetrics.executorId}-$appName"
          } else {
            s"$customApplicationNameSuffix-${CosmosClientMetrics.executorId}-$appName"
          }
        case None => customApplicationNameSuffix
      }

      val metricsOptions = new CosmosMicrometerMetricsOptions()
        .meterRegistry(registry)
        .configureDefaultTagNames(
          CosmosMetricTagName.CONTAINER,
          CosmosMetricTagName.CLIENT_CORRELATION_ID,
          CosmosMetricTagName.OPERATION,
          CosmosMetricTagName.OPERATION_STATUS_CODE,
          CosmosMetricTagName.PARTITION_KEY_RANGE_ID,
          CosmosMetricTagName.SERVICE_ADDRESS,
          CosmosMetricTagName.ADDRESS_RESOLUTION_COLLECTION_MAP_REFRESH,
          CosmosMetricTagName.ADDRESS_RESOLUTION_FORCED_REFRESH,
          CosmosMetricTagName.REQUEST_STATUS_CODE,
          CosmosMetricTagName.REQUEST_OPERATION_TYPE
        )
        .setMetricCategories(
          CosmosMetricCategory.SYSTEM,
          CosmosMetricCategory.OPERATION_SUMMARY,
          CosmosMetricCategory.REQUEST_SUMMARY,
          CosmosMetricCategory.DIRECT_ADDRESS_RESOLUTIONS,
          CosmosMetricCategory.DIRECT_REQUESTS,
          CosmosMetricCategory.DIRECT_CHANNELS
        )

      val telemetryConfig = new CosmosClientTelemetryConfig()
        .metricsOptions(metricsOptions)
        .clientCorrelationId(clientCorrelationId)

      builder.clientTelemetryConfig(telemetryConfig)
    }

    // --- Endpoint discovery / consistency ------------------------------------------------
    // NOTE(review): the flag is named after *TCP connection* endpoint rediscovery, yet the
    // call made is the builder-level endpointDiscoveryEnabled(false). Behavior is kept as-is;
    // confirm that this is the intended setting.
    if (cosmosClientConfiguration.disableTcpConnectionEndpointRediscovery) {
      builder.endpointDiscoveryEnabled(false)
    }

    if (cosmosClientConfiguration.readConsistencyStrategy != ReadConsistencyStrategy.DEFAULT) {
      if (cosmosClientConfiguration.readConsistencyStrategy == ReadConsistencyStrategy.EVENTUAL) {
        // EVENTUAL is applied through the consistency level instead of the strategy.
        builder = builder.consistencyLevel(ConsistencyLevel.EVENTUAL)
      } else {
        builder = builder.readConsistencyStrategy(cosmosClientConfiguration.readConsistencyStrategy)
      }
    }

    // --- Connection mode -----------------------------------------------------------------
    if (cosmosClientConfiguration.useGatewayMode) {
      val gatewayCfg = new GatewayConnectionConfig()
        .setMaxConnectionPoolSize(cosmosClientConfiguration.httpConnectionPoolSize)
      builder = builder.gatewayMode(gatewayCfg)
    } else {
      val directTimeout = Duration.ofSeconds(CosmosConstants.defaultDirectRequestTimeoutInSeconds)
      val baseDirectConfig = new DirectConnectionConfig()
        .setConnectTimeout(directTimeout)
        .setNetworkRequestTimeout(directTimeout)

      // Apply the I/O-threads-per-core override: Spark often works with large payloads and
      // there have been indicators that the default number of I/O threads can be too low
      // for such workloads.
      val directConfigWithIoThreads = SparkBridgeImplementationInternal
        .setIoThreadCountPerCoreFactor(
          baseDirectConfig,
          SparkBridgeImplementationInternal.getIoThreadCountPerCoreOverride)

      // Spark workloads often result in very high CPU load; raising the priority of I/O
      // threads has been seen to reduce transient I/O errors/timeouts in that situation.
      val directConfig = SparkBridgeImplementationInternal
        .setIoThreadPriority(directConfigWithIoThreads, Thread.MAX_PRIORITY)

      builder = builder.directMode(directConfig)

      // Proactive connection warm-up is only set up in direct mode and only for a
      // non-empty configuration value.
      cosmosClientConfiguration.proactiveConnectionInitialization
        .filter(_.nonEmpty)
        .foreach { proactiveInitConfig =>
          val containerIdentities =
            CosmosAccountConfig.parseProactiveConnectionInitConfigs(proactiveInitConfig)

          val initConfig = new CosmosContainerProactiveInitConfigBuilder(containerIdentities)
            .setAggressiveWarmupDuration(
              Duration.ofSeconds(cosmosClientConfiguration.proactiveConnectionInitializationDurationInSeconds))
            .setProactiveConnectionRegionsCount(1)
            .build()
          builder.openConnectionsAndInitCaches(initConfig)
        }
    }

    cosmosClientConfiguration.preferredRegionsList.foreach { regions =>
      builder.preferredRegions(regions.toList.asJava)
    }

    // --- Client telemetry ----------------------------------------------------------------
    // Note: this sets JVM-wide system properties, not only state on this builder.
    if (cosmosClientConfiguration.enableClientTelemetry) {
      System.setProperty(
        "COSMOS.CLIENT_TELEMETRY_ENDPOINT",
        cosmosClientConfiguration.clientTelemetryEndpoint.getOrElse(
          "https://tools.cosmos.azure.com/api/clienttelemetry/trace"
        ))
      System.setProperty(
        "COSMOS.CLIENT_TELEMETRY_ENABLED",
        "true")

      builder.clientTelemetryEnabled(true)
    }

    // --- Metadata cache snapshot ---------------------------------------------------------
    // We saw incidents where even when Spark restarted Executors we haven't been able
    // to recover - most likely due to stale cache state being broadcast.
    // Ideally the SDK would always be able to recover from stale cache state,
    // but the main purpose of broadcasting the cache state is to avoid peaks in metadata
    // RU usage when multiple workers/executors are all started at the same time.
    // Skipping the broadcast cache state for retries should be safe - because not all
    // executors will be restarted at the same time - and it adds an additional layer of safety.
    //
    // Some(n) only when running inside a task whose attempt number n is > 0, i.e. a retry.
    val retryAttemptNumber: Option[Int] =
      Option(TaskContext.get()).map(_.attemptNumber()).filter(_ > 0)

    cosmosClientStateHandle.foreach { handle =>
      retryAttemptNumber match {
        case Some(attemptNumber) =>
          logInfo(s"Ignoring broadcast client state handle because Task is getting retried. " +
            s"Attempt Count: $attemptNumber")
        case None =>
          SparkBridgeImplementationInternal.setMetadataCacheSnapshot(builder, handle)
      }
    }

    // --- Interceptors --------------------------------------------------------------------
    // Each builder interceptor receives the result of the previous one, in order.
    cosmosClientConfiguration.clientBuilderInterceptors.foreach { interceptors =>
      logInfo("Applying CosmosClientBuilder interceptors")
      interceptors.foreach { interceptorFunction =>
        builder = interceptorFunction.apply(builder)
      }
    }

    var client = builder.buildAsyncClient()

    // Client interceptors are chained the same way, on the freshly built client.
    cosmosClientConfiguration.clientInterceptors.foreach { interceptors =>
      logInfo("Applying CosmosClient interceptors")
      interceptors.foreach { interceptorFunction =>
        client = interceptorFunction.apply(client)
      }
    }

    client
  }