src/Core/Resolvers/MsSqlQueryExecutor.cs (287 lines of code) (raw):
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
using System.Data;
using System.Data.Common;
using System.Net;
using System.Text;
using Azure.Core;
using Azure.DataApiBuilder.Config;
using Azure.DataApiBuilder.Config.ObjectModel;
using Azure.DataApiBuilder.Core.Authorization;
using Azure.DataApiBuilder.Core.Configurations;
using Azure.DataApiBuilder.Core.Models;
using Azure.DataApiBuilder.Service.Exceptions;
using Azure.Identity;
using Microsoft.AspNetCore.Http;
using Microsoft.Data.SqlClient;
using Microsoft.Extensions.Logging;
namespace Azure.DataApiBuilder.Core.Resolvers
{
/// <summary>
/// Specialized QueryExecutor for MsSql mainly providing methods to
/// handle connecting to the database with a managed identity.
/// </summary>
public class MsSqlQueryExecutor : QueryExecutor<SqlConnection>
{
/// <summary>
/// Token scope passed to the Azure credential when requesting an access token
/// for a managed identity. This is the same scope for any Azure SQL database.
/// </summary>
public const string DATABASE_SCOPE = @"https://database.windows.net/.default";
/// <summary>
/// The managed identity access token strings obtained
/// from the configuration controller (taken from the runtime config provider at construction).
/// Key: datasource name, Value: access token for this datasource.
/// </summary>
private Dictionary<string, string?> _accessTokensFromConfiguration;
/// <summary>
/// The MsSql specific connection string builders.
/// Key: datasource name, Value: connection string builder for this datasource.
/// Simply exposes the dictionary owned by the base executor.
/// </summary>
public override IDictionary<string, DbConnectionStringBuilder> ConnectionStringBuilders
=> base.ConnectionStringBuilders;
/// <summary>
/// Credential used to request the default access token for <see cref="DATABASE_SCOPE"/>.
/// Settable, so a different credential instance can be substituted.
/// </summary>
public DefaultAzureCredential AzureCredential { get; set; } = new();
/// <summary>
/// The cached access token obtained from <see cref="AzureCredential"/>
/// representing a managed identity. Null until a token has been fetched successfully.
/// </summary>
private AccessToken? _defaultAccessToken;
/// <summary>
/// DatasourceName to boolean value indicating if access token should be set for db.
/// </summary>
private Dictionary<string, bool> _dataSourceAccessTokenUsage;
/// <summary>
/// DatasourceName to boolean value indicating if session context should be set for db.
/// </summary>
private Dictionary<string, bool> _dataSourceToSessionContextUsage;
/// <summary>
/// Runtime config provider supplied at construction; read when configuring the
/// per-datasource connection string builders and usage flags.
/// </summary>
private readonly RuntimeConfigProvider _runtimeConfigProvider;
/// <summary>
/// Key in HttpContext.Items under which the statement IDs reported by the SQL engine
/// are accumulated (separated by ';').
/// </summary>
private const string QUERYIDHEADER = "QueryIdentifyingIds";
/// <summary>
/// Initializes the MsSql query executor: forwards the shared dependencies to the base
/// executor, captures the access tokens exposed by the runtime config provider and
/// builds the per-datasource connection settings.
/// </summary>
/// <param name="runtimeConfigProvider">Provider of the runtime configuration and of the managed identity access tokens.</param>
/// <param name="dbExceptionParser">Exception parser handed to the base executor.</param>
/// <param name="logger">Logger handed to the base executor.</param>
/// <param name="httpContextAccessor">Accessor for the current HTTP context, handed to the base executor.</param>
/// <param name="handler">Optional hot-reload event handler, handed to the base executor.</param>
public MsSqlQueryExecutor(
    RuntimeConfigProvider runtimeConfigProvider,
    DbExceptionParser dbExceptionParser,
    ILogger<IQueryExecutor> logger,
    IHttpContextAccessor httpContextAccessor,
    HotReloadEventHandler<HotReloadEventArgs>? handler = null)
    : base(dbExceptionParser, logger, runtimeConfigProvider, httpContextAccessor, handler)
{
    _runtimeConfigProvider = runtimeConfigProvider;
    _accessTokensFromConfiguration = runtimeConfigProvider.ManagedIdentityAccessToken;
    _dataSourceToSessionContextUsage = new();
    _dataSourceAccessTokenUsage = new();

    // Must run last: it reads _runtimeConfigProvider and fills the two dictionaries above.
    ConfigureMsSqlQueryEecutor();
}
/// <summary>
/// Creates a SqlConnection to the data source of the given name. This method also adds an event handler to
/// the connection's InfoMessage event to extract the statement ID from the request and add it to the HttpContext.
/// The connection is returned unopened.
/// </summary>
/// <param name="dataSourceName">The name of the data source.</param>
/// <returns>The SqlConnection configured with the data source's connection string.</returns>
/// <exception cref="DataApiBuilderException">Thrown if no connection string builder is registered for the data source.</exception>
public override SqlConnection CreateConnection(string dataSourceName)
{
    // Single lookup instead of ContainsKey followed by the indexer.
    if (!ConnectionStringBuilders.TryGetValue(dataSourceName, out DbConnectionStringBuilder? connectionStringBuilder))
    {
        throw new DataApiBuilderException(
            "Query execution failed. Could not find datasource to execute query against",
            HttpStatusCode.BadRequest,
            DataApiBuilderException.SubStatusCodes.DataSourceNotFound);
    }

    SqlConnection conn = new()
    {
        ConnectionString = connectionStringBuilder.ConnectionString,
    };

    // Extract info message from SQLConnection
    conn.InfoMessage += (object sender, SqlInfoMessageEventArgs e) =>
    {
        try
        {
            // Log the statement ids returned by the SQL engine when we executed the batch.
            // This helps in correlating with SQL engine telemetry.
            // If the info message has an error code that matches the well-known codes used for returning statement ID,
            // then we can be certain that the message contains no PII.
            IEnumerable<SqlError> errorsReceived = e.Errors.Cast<SqlError>();
            IEnumerable<SqlInformationalCodes> allInfoCodesKnown = Enum.GetValues(typeof(SqlInformationalCodes)).Cast<SqlInformationalCodes>();
            IEnumerable<string> infoErrorMessagesReceived = errorsReceived.Join(
                allInfoCodesKnown,
                error => error.Number,
                code => (int)code,
                (error, code) => error.Message);

            foreach (string infoErrorMessageReceived in infoErrorMessagesReceived)
            {
                // Add statement ID to request
                AddStatementIDToMiddlewareContext(infoErrorMessageReceived);
            }
        }
        catch (Exception ex)
        {
            // Never let a telemetry problem escape the event handler; log it with the exception
            // attached (structured template instead of an interpolated string).
            QueryExecutorLogger.LogError(
                ex,
                "Error in info message handler while extracting query-identifying ID from SQLConnection. Error: {errorMessage}",
                ex.Message);
        }
    };

    return conn;
}
/// <summary>
/// Populates the per-datasource state (connection string builder, session-context flag and
/// access-token flag) for every MSSQL and DWSQL datasource in the runtime config.
/// Called from the constructor; per the original note it is also meant for hot-reload scenarios.
/// </summary>
private void ConfigureMsSqlQueryEecutor()
{
    IEnumerable<KeyValuePair<string, DataSource>> sqlServerDataSources = _runtimeConfigProvider
        .GetConfig()
        .GetDataSourceNamesToDataSourcesIterator()
        .Where(entry => entry.Value.DatabaseType is DatabaseType.MSSQL or DatabaseType.DWSQL);

    foreach ((string name, DataSource source) in sqlServerDataSources)
    {
        SqlConnectionStringBuilder connectionBuilder = new(source.ConnectionString);

        // A late-configured runtime always gets an encrypted connection with certificate validation.
        if (_runtimeConfigProvider.IsLateConfigured)
        {
            connectionBuilder.Encrypt = SqlConnectionEncryptOption.Mandatory;
            connectionBuilder.TrustServerCertificate = false;
        }

        // TryAdd leaves an already registered builder for this name untouched.
        ConnectionStringBuilders.TryAdd(name, connectionBuilder);

        // Session context is used only when MsSql options exist and explicitly enable it.
        _dataSourceToSessionContextUsage[name] = source.GetTypedOptions<MsSqlOptions>() is { SetSessionContext: true };
        _dataSourceAccessTokenUsage[name] = ShouldManagedIdentityAccessBeAttempted(connectionBuilder);
    }
}
/// <summary>
/// Adjusts the supplied connection so that it can authenticate with a managed identity.
/// For MsSql this means: when the datasource is flagged for access-token usage, pick a token
/// and assign it to the connection. The connection is assumed to already carry the
/// connection string from the runtime configuration.
/// </summary>
/// <param name="conn">The connection to modify; must be a SqlConnection when a token is to be set.</param>
/// <param name="dataSourceName">Name of the datasource whose token settings apply. When null or empty, the default datasource name from the config is used.</param>
public override async Task SetManagedIdentityAccessTokenIfAnyAsync(DbConnection conn, string dataSourceName)
{
    // Fall back to the default datasource to stay backward compatible with the single-db scenario.
    if (string.IsNullOrEmpty(dataSourceName))
    {
        dataSourceName = ConfigProvider.GetConfig().DefaultDataSourceName;
    }

    // Datasources without an entry are treated the same as those flagged false.
    bool useAccessToken = _dataSourceAccessTokenUsage.TryGetValue(dataSourceName, out bool flagged) && flagged;
    if (!useAccessToken)
    {
        return;
    }

    SqlConnection sqlConnection = (SqlConnection)conn;

    // Precedence: token supplied by the configuration controller, then the cached default
    // token while it is still valid, and only then a freshly requested token.
    _accessTokensFromConfiguration.TryGetValue(dataSourceName, out string? token);
    if (token is null)
    {
        if (IsDefaultAccessTokenValid())
        {
            token = _defaultAccessToken!.Value.Token;
        }
        else
        {
            token = await GetAccessTokenAsync();
        }
    }

    if (token is not null)
    {
        sqlConnection.AccessToken = token;
    }
}
/// <summary>
/// Decides whether a managed identity access token should be attempted for a connection string.
/// It is attempted only when
/// 1. none of UserID, Password or Authentication method are specified, since those take
/// precedence and using an access token alongside them would lead to
/// a System.InvalidOperationException, and
/// 2. the connection string does NOT request Windows Integrated Security.
/// </summary>
/// <param name="builder">Connection string builder to inspect.</param>
/// <returns>True when none of the conflicting settings is present, false otherwise.</returns>
private static bool ShouldManagedIdentityAccessBeAttempted(SqlConnectionStringBuilder builder)
{
    // Explicit SQL credentials win over an access token.
    if (!string.IsNullOrEmpty(builder.UserID) || !string.IsNullOrEmpty(builder.Password))
    {
        return false;
    }

    // So does an explicitly chosen authentication method.
    if (builder.Authentication != SqlAuthenticationMethod.NotSpecified)
    {
        return false;
    }

    return !builder.IntegratedSecurity;
}
/// <summary>
/// Checks whether a default access token has been cached and its expiry still lies in the future.
/// </summary>
/// <returns>True if a cached token exists and has not expired, false otherwise.</returns>
private bool IsDefaultAccessTokenValid()
{
    if (_defaultAccessToken is not AccessToken cachedToken)
    {
        return false;
    }

    return cachedToken.ExpiresOn > DateTimeOffset.Now;
}
/// <summary>
/// Tries to get an access token using DefaultAzureCredentials and caches it on success.
/// Catches any CredentialUnavailableException and logs only a warning
/// since this is best effort.
/// </summary>
/// <returns>The string representation of the access token if one could be obtained,
/// null otherwise. A previously cached token is only returned after a failed
/// request if it is still valid.</returns>
private async Task<string?> GetAccessTokenAsync()
{
    try
    {
        _defaultAccessToken = await AzureCredential.GetTokenAsync(new TokenRequestContext(new[] { DATABASE_SCOPE }));
    }
    catch (CredentialUnavailableException ex)
    {
        string correlationId = HttpContextExtensions.GetLoggerCorrelationId(HttpContextAccessor.HttpContext);
        QueryExecutorLogger.LogWarning(
            message: "{correlationId} Failed to retrieve a managed identity access token using DefaultAzureCredential due to:\n{errorMessage}",
            correlationId,
            ex.Message);

        // The request failed, so the cached token was not refreshed. This method is invoked when
        // the cached token is missing or expired; handing that stale token back would put an
        // expired token on the connection. Only return it if it is (still) valid.
        return IsDefaultAccessTokenValid() ? _defaultAccessToken!.Value.Token : null;
    }

    return _defaultAccessToken?.Token;
}
/// <summary>
/// Method to generate the query to send user data to the underlying database via SESSION_CONTEXT which might be used
/// for additional security (eg. using Security Policies) at the database level. The max payload limit for SESSION_CONTEXT is 1MB.
/// Each claim value is passed as a query parameter (added to <paramref name="parameters"/>); the claim type is
/// embedded as a T-SQL string literal with single quotes escaped.
/// </summary>
/// <param name="httpContext">Current user httpContext.</param>
/// <param name="parameters">Dictionary of parameters/value required to execute the query. One entry per user claim is added to it.</param>
/// <param name="dataSourceName">Name of the datasource whose session-context setting is consulted. When null or empty, the default datasource name from the config is used.</param>
/// <returns>empty string / query to set session parameters for the connection.</returns>
/// <seealso href="https://learn.microsoft.com/en-us/sql/relational-databases/system-stored-procedures/sp-set-session-context-transact-sql?view=sql-server-ver16"/>
public override string GetSessionParamsQuery(HttpContext? httpContext, IDictionary<string, DbConnectionParam> parameters, string dataSourceName)
{
    if (string.IsNullOrEmpty(dataSourceName))
    {
        dataSourceName = ConfigProvider.GetConfig().DefaultDataSourceName;
    }

    if (httpContext is null || !_dataSourceToSessionContextUsage[dataSourceName])
    {
        return string.Empty;
    }

    // Dictionary containing all the claims belonging to the user, to be used as session parameters.
    Dictionary<string, string> sessionParams = AuthorizationResolver.GetProcessedUserClaims(httpContext);

    // Counter to generate different param name for each of the sessionParam.
    IncrementingInteger counter = new();
    const string SESSION_PARAM_NAME = $"{BaseQueryStructure.PARAM_NAME_PREFIX}session_param";
    StringBuilder sessionMapQuery = new();

    foreach ((string claimType, string claimValue) in sessionParams)
    {
        string paramName = $"{SESSION_PARAM_NAME}{counter.Next()}";
        parameters.Add(paramName, new(claimValue));

        // The claim type ends up inside a single-quoted T-SQL literal, so any embedded quote
        // must be doubled; otherwise a crafted claim name could break out of the literal.
        string escapedClaimType = claimType.Replace("'", "''");

        // Append statement to set read only param value - can be set only once for a connection.
        sessionMapQuery.Append($"EXEC sp_set_session_context '{escapedClaimType}', {paramName}, @read_only = 1;");
    }

    return sessionMapQuery.ToString();
}
/// <inheritdoc/>
/// <remarks>
/// Expects up to two result sets from the reader: the first carries the count (0/1) of rows
/// matching the given primary key, the second the rows produced by the update/insert.
/// When <paramref name="args"/> holds more than one element, element 0 is used as the printable
/// primary key and element 1 as the entity name in the "not found" error.
/// </remarks>
public override async Task<DbResultSet> GetMultipleResultSetsIfAnyAsync(
    DbDataReader dbDataReader, List<string>? args = null)
{
    // First result set: count (0/1) of records with the given PK.
    DbResultSet countResultSet = await ExtractResultSetFromDbDataReaderAsync(dbDataReader);
    DbResultSetRow? countRow = countResultSet.Rows.FirstOrDefault();
    if (countRow is null ||
        !countRow.Columns.TryGetValue(MsSqlQueryBuilder.COUNT_ROWS_WITH_GIVEN_PK, out object? boxedCount))
    {
        throw new DataApiBuilderException(
            message: "Neither insert nor update could be performed.",
            statusCode: HttpStatusCode.InternalServerError,
            subStatusCode: DataApiBuilderException.SubStatusCodes.UnexpectedError);
    }

    // A count of 1 means a record existed for the PK, i.e. an update was attempted.
    bool updateWasAttempted = (int)boxedCount! == 1;

    // Second result set: records returned by the executed update/insert operation.
    DbResultSet? mutationResultSet = null;
    if (await dbDataReader.NextResultAsync())
    {
        mutationResultSet = await ExtractResultSetFromDbDataReaderAsync(dbDataReader);
    }

    if (mutationResultSet is null)
    {
        // For a PUT/PATCH on a table/view with a non-autogen PK either an insert or an update is
        // always performed, so a second result set must exist. With an autogen PK no insert is
        // attempted, and an update only happens when a record exists for the PK. Reaching this
        // point therefore means no update happened either, i.e. the row count for the PK was 0.
        if (args is { Count: > 1 })
        {
            string printablePk = args[0];
            string entity = args[1];
            throw new DataApiBuilderException(
                message: $"Cannot perform INSERT and could not find {entity} with primary key {printablePk} to perform UPDATE on.",
                statusCode: HttpStatusCode.NotFound,
                subStatusCode: DataApiBuilderException.SubStatusCodes.ItemNotFound);
        }

        throw new DataApiBuilderException(
            message: "Neither insert nor update could be performed.",
            statusCode: HttpStatusCode.InternalServerError,
            subStatusCode: DataApiBuilderException.SubStatusCodes.UnexpectedError);
    }

    // An empty second result set - whether an update or an insert was attempted -
    // indicates a database policy failure.
    if (mutationResultSet.Rows.Count == 0)
    {
        throw new DataApiBuilderException(
            message: DataApiBuilderException.AUTHORIZATION_FAILURE,
            statusCode: HttpStatusCode.Forbidden,
            subStatusCode: DataApiBuilderException.SubStatusCodes.DatabasePolicyFailure);
    }

    if (updateWasAttempted)
    {
        // Marker to distinguish update from insert results;
        // later helps to add the location header in case of an insert operation.
        mutationResultSet.ResultProperties.Add(SqlMutationEngine.IS_UPDATE_RESULT_SET, true);
    }

    return mutationResultSet;
}
/// <inheritdoc />
/// <remarks>
/// The command text is the session-parameters query (possibly empty) followed by
/// <paramref name="sqltext"/>; every entry of <paramref name="parameters"/> becomes a SqlParameter,
/// with a null value mapped to DBNull.
/// </remarks>
public override SqlCommand PrepareDbCommand(
    SqlConnection conn,
    string sqltext,
    IDictionary<string, DbConnectionParam> parameters,
    HttpContext? httpContext,
    string dataSourceName)
{
    SqlCommand command = conn.CreateCommand();
    command.CommandType = CommandType.Text;

    // Prefix the query that sends user data from DAB to the database, enabling any additional
    // security the user might have configured at the database level. Building it may add
    // entries to the parameters dictionary, so it has to happen before the loop below.
    command.CommandText = GetSessionParamsQuery(httpContext, parameters, dataSourceName) + sqltext;

    if (parameters is null)
    {
        return command;
    }

    foreach (KeyValuePair<string, DbConnectionParam> entry in parameters)
    {
        SqlParameter sqlParameter = command.CreateParameter();
        sqlParameter.ParameterName = entry.Key;
        sqlParameter.Value = entry.Value.Value ?? DBNull.Value;
        PopulateDbTypeForParameter(entry, sqlParameter);
        command.Parameters.Add(sqlParameter);
    }

    return command;
}
/// <summary>
/// Copies the type information carried by a parameter entry onto a SqlParameter.
/// DbType is applied first and SqlDbType second; each is applied only when it is set on the entry.
/// Nothing happens when the entry has no value object.
/// </summary>
/// <param name="parameterEntry">Parameter name together with its value/type description.</param>
/// <param name="parameter">The SqlParameter whose DbType/SqlDbType are to be populated.</param>
public static void PopulateDbTypeForParameter(KeyValuePair<string, DbConnectionParam> parameterEntry, SqlParameter parameter)
{
    DbConnectionParam? source = parameterEntry.Value;
    if (source is null)
    {
        return;
    }

    if (source.DbType is DbType dbType)
    {
        parameter.DbType = dbType;
    }

    if (source.SqlDbType is SqlDbType sqlDbType)
    {
        parameter.SqlDbType = sqlDbType;
    }
}
/// <summary>
/// Records a statement ID in the current HttpContext's Items under the query-identifying key.
/// The first ID is stored as-is; subsequent IDs are appended, separated by ';'.
/// Does nothing when no HttpContext is available. If the stored value is not a string,
/// a warning is logged and the ID is not added.
/// </summary>
/// <param name="statementId">The statement ID to record.</param>
private void AddStatementIDToMiddlewareContext(string statementId)
{
    HttpContext? httpContext = HttpContextAccessor?.HttpContext;
    if (httpContext is null)
    {
        return;
    }

    // Locking is needed because a single http request can run multiple queries in parallel,
    // leading to concurrent access of httpContext.Items.
    lock (_httpContextLock)
    {
        if (!httpContext.Items.TryGetValue(QUERYIDHEADER, out object? currentValue) || currentValue is null)
        {
            httpContext.Items[QUERYIDHEADER] = statementId;
            return;
        }

        // Type-test instead of cast-and-catch: no exception for control flow, and
        // unrelated failures are no longer swallowed by a bare catch.
        if (currentValue is string existingIds)
        {
            httpContext.Items[QUERYIDHEADER] = existingIds + ";" + statementId;
        }
        else
        {
            QueryExecutorLogger.LogWarning("Could not cast query identifying ID to string. The ID was not added to httpcontext");
        }
    }
}
}
}