csharp/src/Drivers/Databricks/DatabricksConnection.cs (340 lines of code) (raw):

/* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ using System; using System.Collections.Generic; using System.Diagnostics; using System.Linq; using System.Net.Http; using System.Net.Http.Headers; using System.Threading; using System.Threading.Tasks; using Apache.Arrow.Adbc.Drivers.Apache; using Apache.Arrow.Adbc.Drivers.Apache.Spark; using Apache.Arrow.Adbc.Drivers.Databricks.Auth; using Apache.Arrow.Adbc.Drivers.Databricks.CloudFetch; using Apache.Arrow.Ipc; using Apache.Hive.Service.Rpc.Thrift; namespace Apache.Arrow.Adbc.Drivers.Databricks { internal class DatabricksConnection : SparkHttpConnection { private bool _applySSPWithQueries = false; private bool _enableDirectResults = true; internal static TSparkGetDirectResults defaultGetDirectResults = new() { MaxRows = 2000000, MaxBytes = 404857600 }; // CloudFetch configuration private const long DefaultMaxBytesPerFile = 20 * 1024 * 1024; // 20MB private bool _useCloudFetch = true; private bool _canDecompressLz4 = true; private long _maxBytesPerFile = DefaultMaxBytesPerFile; private const bool DefaultRetryOnUnavailable= true; private const int DefaultTemporarilyUnavailableRetryTimeout = 500; public DatabricksConnection(IReadOnlyDictionary<string, string> properties) : base(properties) { ValidateProperties(); } private void ValidateProperties() { if (Properties.TryGetValue(DatabricksParameters.ApplySSPWithQueries, out string? applySSPWithQueriesStr)) { if (bool.TryParse(applySSPWithQueriesStr, out bool applySSPWithQueriesValue)) { _applySSPWithQueries = applySSPWithQueriesValue; } else { throw new ArgumentException($"Parameter '{DatabricksParameters.ApplySSPWithQueries}' value '{applySSPWithQueriesStr}' could not be parsed. Valid values are 'true' and 'false'."); } } if (Properties.TryGetValue(DatabricksParameters.EnableDirectResults, out string? enableDirectResultsStr)) { if (bool.TryParse(enableDirectResultsStr, out bool enableDirectResultsValue)) { _enableDirectResults = enableDirectResultsValue; } else { throw new ArgumentException($"Parameter '{DatabricksParameters.EnableDirectResults}' value '{enableDirectResultsStr}' could not be parsed. Valid values are 'true' and 'false'."); } } // Parse CloudFetch options from connection properties if (Properties.TryGetValue(DatabricksParameters.UseCloudFetch, out string? useCloudFetchStr)) { if (bool.TryParse(useCloudFetchStr, out bool useCloudFetchValue)) { _useCloudFetch = useCloudFetchValue; } else { throw new ArgumentException($"Parameter '{DatabricksParameters.UseCloudFetch}' value '{useCloudFetchStr}' could not be parsed. Valid values are 'true' and 'false'."); } } if (Properties.TryGetValue(DatabricksParameters.CanDecompressLz4, out string? canDecompressLz4Str)) { if (bool.TryParse(canDecompressLz4Str, out bool canDecompressLz4Value)) { _canDecompressLz4 = canDecompressLz4Value; } else { throw new ArgumentException($"Parameter '{DatabricksParameters.CanDecompressLz4}' value '{canDecompressLz4Str}' could not be parsed. Valid values are 'true' and 'false'."); } } if (Properties.TryGetValue(DatabricksParameters.MaxBytesPerFile, out string? maxBytesPerFileStr)) { if (!long.TryParse(maxBytesPerFileStr, out long maxBytesPerFileValue)) { throw new ArgumentException($"Parameter '{DatabricksParameters.MaxBytesPerFile}' value '{maxBytesPerFileStr}' could not be parsed. Valid values are positive integers."); } if (maxBytesPerFileValue <= 0) { throw new ArgumentOutOfRangeException( nameof(Properties), maxBytesPerFileValue, $"Parameter '{DatabricksParameters.MaxBytesPerFile}' value must be a positive integer."); } _maxBytesPerFile = maxBytesPerFileValue; } } /// <summary> /// Gets whether server side properties should be applied using queries. /// </summary> internal bool ApplySSPWithQueries => _applySSPWithQueries; /// <summary> /// Gets whether direct results are enabled. /// </summary> internal bool EnableDirectResults => _enableDirectResults; /// <summary> /// Gets whether CloudFetch is enabled. /// </summary> internal bool UseCloudFetch => _useCloudFetch; /// <summary> /// Gets whether LZ4 decompression is enabled. /// </summary> internal bool CanDecompressLz4 => _canDecompressLz4; /// <summary> /// Gets the maximum bytes per file for CloudFetch. /// </summary> internal long MaxBytesPerFile => _maxBytesPerFile; /// <summary> /// Gets a value indicating whether to retry requests that receive a 503 response with a Retry-After header. /// </summary> protected bool TemporarilyUnavailableRetry { get; private set; } = DefaultRetryOnUnavailable; /// <summary> /// Gets the maximum total time in seconds to retry 503 responses before failing. /// </summary> protected int TemporarilyUnavailableRetryTimeout { get; private set; } = DefaultTemporarilyUnavailableRetryTimeout; protected override HttpMessageHandler CreateHttpHandler() { HttpMessageHandler baseHandler = base.CreateHttpHandler(); if (TemporarilyUnavailableRetry) { // Add OAuth handler if OAuth authentication is being used baseHandler = new RetryHttpHandler(baseHandler, TemporarilyUnavailableRetryTimeout); } // Add OAuth handler if OAuth authentication is being used if (Properties.TryGetValue(SparkParameters.AuthType, out string? authType) && SparkAuthTypeParser.TryParse(authType, out SparkAuthType authTypeValue) && authTypeValue == SparkAuthType.OAuth && Properties.TryGetValue(DatabricksParameters.OAuthGrantType, out string? grantTypeStr) && DatabricksOAuthGrantTypeParser.TryParse(grantTypeStr, out DatabricksOAuthGrantType grantType) && grantType == DatabricksOAuthGrantType.ClientCredentials) { // Note: We assume that properties have already been validated if (Properties.TryGetValue(SparkParameters.HostName, out string? host) && !string.IsNullOrEmpty(host)) { // Use hostname directly if provided } else if (Properties.TryGetValue(AdbcOptions.Uri, out string? uri) && !string.IsNullOrEmpty(uri)) { // Extract hostname from URI if URI is provided if (Uri.TryCreate(uri, UriKind.Absolute, out Uri? parsedUri)) { host = parsedUri.Host; } } Properties.TryGetValue(DatabricksParameters.OAuthClientId, out string? clientId); Properties.TryGetValue(DatabricksParameters.OAuthClientSecret, out string? clientSecret); var tokenProvider = new OAuthClientCredentialsProvider( clientId!, clientSecret!, host!, timeoutMinutes: 1 ); return new OAuthDelegatingHandler(baseHandler, tokenProvider); } return baseHandler; } protected internal override bool AreResultsAvailableDirectly => _enableDirectResults; protected override void SetDirectResults(TGetColumnsReq request) => request.GetDirectResults = defaultGetDirectResults; protected override void SetDirectResults(TGetCatalogsReq request) => request.GetDirectResults = defaultGetDirectResults; protected override void SetDirectResults(TGetSchemasReq request) => request.GetDirectResults = defaultGetDirectResults; protected override void SetDirectResults(TGetTablesReq request) => request.GetDirectResults = defaultGetDirectResults; protected override void SetDirectResults(TGetTableTypesReq request) => request.GetDirectResults = defaultGetDirectResults; protected override void SetDirectResults(TGetPrimaryKeysReq request) => request.GetDirectResults = defaultGetDirectResults; protected override void SetDirectResults(TGetCrossReferenceReq request) => request.GetDirectResults = defaultGetDirectResults; internal override IArrowArrayStream NewReader<T>(T statement, Schema schema, TGetResultSetMetadataResp? metadataResp = null) { // Get result format from metadata response if available TSparkRowSetType resultFormat = TSparkRowSetType.ARROW_BASED_SET; bool isLz4Compressed = false; DatabricksStatement? databricksStatement = statement as DatabricksStatement; if (databricksStatement == null) { throw new InvalidOperationException("Cannot obtain a reader for Databricks"); } if (metadataResp != null) { if (metadataResp.__isset.resultFormat) { resultFormat = metadataResp.ResultFormat; } if (metadataResp.__isset.lz4Compressed) { isLz4Compressed = metadataResp.Lz4Compressed; } } // Choose the appropriate reader based on the result format if (resultFormat == TSparkRowSetType.URL_BASED_SET) { return new CloudFetchReader(databricksStatement, schema, isLz4Compressed); } else { return new DatabricksReader(databricksStatement, schema, isLz4Compressed); } } internal override SchemaParser SchemaParser => new DatabricksSchemaParser(); public override AdbcStatement CreateStatement() { DatabricksStatement statement = new DatabricksStatement(this); return statement; } protected override TOpenSessionReq CreateSessionRequest() { var req = new TOpenSessionReq { Client_protocol = TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V7, Client_protocol_i64 = (long)TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V7, CanUseMultipleCatalogs = true, }; // If not using queries to set server-side properties, include them in Configuration if (!_applySSPWithQueries) { req.Configuration = new Dictionary<string, string>(); var serverSideProperties = GetServerSideProperties(); foreach (var property in serverSideProperties) { req.Configuration[property.Key] = property.Value; } } return req; } /// <summary> /// Gets a dictionary of server-side properties extracted from connection properties. /// </summary> /// <returns>Dictionary of server-side properties with prefix removed from keys.</returns> private Dictionary<string, string> GetServerSideProperties() { return Properties .Where(p => p.Key.StartsWith(DatabricksParameters.ServerSidePropertyPrefix)) .ToDictionary( p => p.Key.Substring(DatabricksParameters.ServerSidePropertyPrefix.Length), p => p.Value ); } /// <summary> /// Applies server-side properties by executing "set key=value" queries. /// </summary> /// <returns>A task representing the asynchronous operation.</returns> public async Task ApplyServerSidePropertiesAsync() { if (!_applySSPWithQueries) { return; } var serverSideProperties = GetServerSideProperties(); if (serverSideProperties.Count == 0) { return; } using var statement = new DatabricksStatement(this); foreach (var property in serverSideProperties) { if (!IsValidPropertyName(property.Key)) { Debug.WriteLine($"Skipping invalid property name: {property.Key}"); continue; } string escapedValue = EscapeSqlString(property.Value); string query = $"SET {property.Key}={escapedValue}"; statement.SqlQuery = query; try { await statement.ExecuteUpdateAsync(); } catch (Exception ex) { Debug.WriteLine($"Error setting server-side property '{property.Key}': {ex.Message}"); } } } private bool IsValidPropertyName(string propertyName) { // Allow only letters and underscores in property names return System.Text.RegularExpressions.Regex.IsMatch( propertyName, @"^[a-zA-Z_]+$"); } private string EscapeSqlString(string value) { return "`" + value.Replace("`", "``") + "`"; } protected override void ValidateOptions() { base.ValidateOptions(); if (Properties.TryGetValue(DatabricksParameters.TemporarilyUnavailableRetry, out string? tempUnavailableRetryStr)) { if (!bool.TryParse(tempUnavailableRetryStr, out bool tempUnavailableRetryValue)) { throw new ArgumentOutOfRangeException(DatabricksParameters.TemporarilyUnavailableRetry, tempUnavailableRetryStr, $"must be a value of false (disabled) or true (enabled). Default is true."); } TemporarilyUnavailableRetry = tempUnavailableRetryValue; } if(Properties.TryGetValue(DatabricksParameters.TemporarilyUnavailableRetryTimeout, out string? tempUnavailableRetryTimeoutStr)) { if (!int.TryParse(tempUnavailableRetryTimeoutStr, out int tempUnavailableRetryTimeoutValue) || tempUnavailableRetryTimeoutValue < 0) { throw new ArgumentOutOfRangeException(DatabricksParameters.TemporarilyUnavailableRetryTimeout, tempUnavailableRetryTimeoutStr, $"must be a value of 0 (retry indefinitely) or a positive integer representing seconds. Default is 900 seconds (15 minutes)."); } TemporarilyUnavailableRetryTimeout = tempUnavailableRetryTimeoutValue; } } protected override Task<TGetResultSetMetadataResp> GetResultSetMetadataAsync(TGetSchemasResp response, CancellationToken cancellationToken = default) => Task.FromResult(response.DirectResults.ResultSetMetadata); protected override Task<TGetResultSetMetadataResp> GetResultSetMetadataAsync(TGetCatalogsResp response, CancellationToken cancellationToken = default) => Task.FromResult(response.DirectResults.ResultSetMetadata); protected override Task<TGetResultSetMetadataResp> GetResultSetMetadataAsync(TGetColumnsResp response, CancellationToken cancellationToken = default) => Task.FromResult(response.DirectResults.ResultSetMetadata); protected override Task<TGetResultSetMetadataResp> GetResultSetMetadataAsync(TGetTablesResp response, CancellationToken cancellationToken = default) => Task.FromResult(response.DirectResults.ResultSetMetadata); protected internal override Task<TGetResultSetMetadataResp> GetResultSetMetadataAsync(TGetPrimaryKeysResp response, CancellationToken cancellationToken = default) => Task.FromResult(response.DirectResults.ResultSetMetadata); protected override Task<TRowSet> GetRowSetAsync(TGetTableTypesResp response, CancellationToken cancellationToken = default) => Task.FromResult(response.DirectResults.ResultSet.Results); protected override Task<TRowSet> GetRowSetAsync(TGetColumnsResp response, CancellationToken cancellationToken = default) => Task.FromResult(response.DirectResults.ResultSet.Results); protected override Task<TRowSet> GetRowSetAsync(TGetTablesResp response, CancellationToken cancellationToken = default) => Task.FromResult(response.DirectResults.ResultSet.Results); protected override Task<TRowSet> GetRowSetAsync(TGetCatalogsResp response, CancellationToken cancellationToken = default) => Task.FromResult(response.DirectResults.ResultSet.Results); protected override Task<TRowSet> GetRowSetAsync(TGetSchemasResp response, CancellationToken cancellationToken = default) => Task.FromResult(response.DirectResults.ResultSet.Results); protected internal override Task<TRowSet> GetRowSetAsync(TGetPrimaryKeysResp response, CancellationToken cancellationToken = default) => Task.FromResult(response.DirectResults.ResultSet.Results); protected override AuthenticationHeaderValue? GetAuthenticationHeaderValue(SparkAuthType authType) { if (authType == SparkAuthType.OAuth) { Properties.TryGetValue(DatabricksParameters.OAuthGrantType, out string? grantTypeStr); if (DatabricksOAuthGrantTypeParser.TryParse(grantTypeStr, out DatabricksOAuthGrantType grantType) && grantType == DatabricksOAuthGrantType.ClientCredentials) { // Return null for client credentials flow since OAuth handler will handle authentication return null; } } return base.GetAuthenticationHeaderValue(authType); } protected override void ValidateOAuthParameters() { Properties.TryGetValue(DatabricksParameters.OAuthGrantType, out string? grantTypeStr); DatabricksOAuthGrantType grantType; if (!DatabricksOAuthGrantTypeParser.TryParse(grantTypeStr, out grantType)) { throw new ArgumentOutOfRangeException( DatabricksParameters.OAuthGrantType, grantTypeStr, $"Unsupported {DatabricksParameters.OAuthGrantType} value. Refer to the Databricks documentation for valid values." ); } // If we have a valid grant type, validate the required parameters if (grantType == DatabricksOAuthGrantType.ClientCredentials) { Properties.TryGetValue(DatabricksParameters.OAuthClientId, out string? clientId); Properties.TryGetValue(DatabricksParameters.OAuthClientSecret, out string? clientSecret); if (string.IsNullOrEmpty(clientId)) { throw new ArgumentException( $"Parameter '{DatabricksParameters.OAuthGrantType}' is set to '{DatabricksConstants.OAuthGrantTypes.ClientCredentials}' but parameter '{DatabricksParameters.OAuthClientId}' is not set. Please provide a value for '{DatabricksParameters.OAuthClientId}'.", nameof(Properties)); } if (string.IsNullOrEmpty(clientSecret)) { throw new ArgumentException( $"Parameter '{DatabricksParameters.OAuthGrantType}' is set to '{DatabricksConstants.OAuthGrantTypes.ClientCredentials}' but parameter '{DatabricksParameters.OAuthClientSecret}' is not set. Please provide a value for '{DatabricksParameters.OAuthClientSecret}'.", nameof(Properties)); } } else { // For other auth flows, use default OAuth validation base.ValidateOAuthParameters(); } } } }