src/Core/Resolvers/CosmosQueryEngine.cs (325 lines of code) (raw):
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
#nullable disable
using System.Text;
using System.Text.Json;
using Azure.DataApiBuilder.Auth;
using Azure.DataApiBuilder.Config.ObjectModel;
using Azure.DataApiBuilder.Core.Configurations;
using Azure.DataApiBuilder.Core.Models;
using Azure.DataApiBuilder.Core.Services;
using Azure.DataApiBuilder.Core.Services.Cache;
using Azure.DataApiBuilder.Core.Services.MetadataProviders;
using Azure.DataApiBuilder.Service.GraphQLBuilder.Queries;
using Azure.DataApiBuilder.Service.Services;
using HotChocolate.Language;
using HotChocolate.Resolvers;
using Microsoft.AspNetCore.Mvc;
using Microsoft.Azure.Cosmos;
using Newtonsoft.Json.Linq;
namespace Azure.DataApiBuilder.Core.Resolvers
{
/// <summary>
/// CosmosQueryEngine to execute queries against CosmosDb.
/// </summary>
public class CosmosQueryEngine : IQueryEngine
{
// Source of the CosmosClient instances, looked up by data source name.
private readonly CosmosClientProvider _clientProvider;
// Resolves the metadata provider that belongs to a data source.
private readonly IMetadataProviderFactory _metadataProviderFactory;
// Turns a CosmosQueryStructure into Cosmos SQL query text.
private readonly CosmosQueryBuilder _queryBuilder;
// Handed to every CosmosQueryStructure that is built for a request.
private readonly GQLFilterParser _gQLFilterParser;
// Handed to every CosmosQueryStructure that is built for a request.
private readonly IAuthorizationResolver _authorizationResolver;
// Gives access to the runtime configuration (entities, cache settings).
private readonly RuntimeConfigProvider _runtimeConfigProvider;
// Cache consulted for single-result queries of entities that have caching enabled.
private readonly DabCacheService _cache;

/// <summary>
/// Creates the engine from its injected collaborators. The
/// <see cref="CosmosQueryBuilder"/> is not injected; a fresh instance is created here.
/// </summary>
/// <param name="clientProvider">Provider of the Cosmos clients, keyed by data source name.</param>
/// <param name="metadataProviderFactory">Factory used to obtain the metadata provider of a data source.</param>
/// <param name="authorizationResolver">Authorization resolver forwarded to the query structure.</param>
/// <param name="gQLFilterParser">GraphQL filter parser forwarded to the query structure.</param>
/// <param name="runtimeConfigProvider">Provider of the runtime configuration.</param>
/// <param name="cache">Cache service used for cacheable entities.</param>
public CosmosQueryEngine(
    CosmosClientProvider clientProvider,
    IMetadataProviderFactory metadataProviderFactory,
    IAuthorizationResolver authorizationResolver,
    GQLFilterParser gQLFilterParser,
    RuntimeConfigProvider runtimeConfigProvider,
    DabCacheService cache)
{
    // Connectivity and metadata.
    _clientProvider = clientProvider;
    _metadataProviderFactory = metadataProviderFactory;

    // Request processing helpers.
    _authorizationResolver = authorizationResolver;
    _gQLFilterParser = gQLFilterParser;

    // Configuration and caching.
    _runtimeConfigProvider = runtimeConfigProvider;
    _cache = cache;

    // Owned by the engine rather than injected.
    _queryBuilder = new();
}
/// <summary>
/// Executes the GraphQL query described by the given IMiddlewareContext and
/// returns a single JSON document.
/// </summary>
/// <param name="context">Middleware context of the GraphQL field being resolved.</param>
/// <param name="parameters">Arguments of the field (for example id or filter), keyed by argument name.</param>
/// <param name="dataSourceName">Name of the data source whose Cosmos client and metadata provider are used.</param>
/// <returns>
/// A tuple whose first item is the query result parsed as a <see cref="JsonDocument"/>
/// (null when the query produced no result) and whose second item is always null.
/// </returns>
public async Task<Tuple<JsonDocument, IMetadata>> ExecuteAsync(
    IMiddlewareContext context,
    IDictionary<string, object> parameters,
    string dataSourceName)
{
    // TODO: add support for join query against another container
    // TODO: add support for TOP and Order-by push-down
    ISqlMetadataProvider metadataStoreProvider = _metadataProviderFactory.GetMetadataProvider(dataSourceName);
    CosmosQueryStructure structure = new(context, parameters, _runtimeConfigProvider, metadataStoreProvider, _authorizationResolver, _gQLFilterParser);
    RuntimeConfig runtimeConfig = _runtimeConfigProvider.GetConfig();

    string queryString = _queryBuilder.Build(structure);
    QueryDefinition querySpec = new(queryString);
    QueryRequestOptions queryRequestOptions = new();

    CosmosClient client = _clientProvider.Clients[dataSourceName];
    Container container = client.GetDatabase(structure.Database).GetContainer(structure.Container);

    (string idValue, string partitionKeyValue) = await GetIdAndPartitionKey(context, parameters, container, structure, metadataStoreProvider);

    foreach (KeyValuePair<string, DbConnectionParam> parameterEntry in structure.Parameters)
    {
        querySpec = querySpec.WithParameter(parameterEntry.Key, parameterEntry.Value.Value);
    }

    // Scope the query to one partition when a partition key value could be resolved.
    if (!string.IsNullOrEmpty(partitionKeyValue))
    {
        queryRequestOptions.PartitionKey = new PartitionKey(partitionKeyValue);
    }

    JObject executeQueryResult;
    if (runtimeConfig.CanUseCache() && runtimeConfig.Entities[structure.EntityName].IsCachingEnabled)
    {
        // The cache key is made of the query text, the query parameters and the data source string.
        // Everything that shapes the result without appearing in the query text or its parameters
        // has to be folded into the data source string, otherwise distinct requests share an entry:
        //   - the continuation token (which page is requested),
        //   - the page size, applied through QueryRequestOptions.MaxItemCount,
        //   - the partition key, applied through QueryRequestOptions.PartitionKey.
        string dataSourceKey = $"{dataSourceName}:{structure.Continuation}:{structure.MaxItemCount}:{partitionKeyValue}";

        DatabaseQueryMetadata queryMetadata = new(queryText: queryString, dataSource: dataSourceKey, queryParameters: structure.Parameters);
        executeQueryResult = await _cache.GetOrSetAsync<JObject>(
            async () => await ExecuteQueryAsync(structure, querySpec, queryRequestOptions, container, idValue, partitionKeyValue),
            queryMetadata,
            runtimeConfig.GetEntityCacheEntryTtl(entityName: structure.EntityName));
    }
    else
    {
        executeQueryResult = await ExecuteQueryAsync(structure, querySpec, queryRequestOptions, container, idValue, partitionKeyValue);
    }

    // Bridge from Newtonsoft (JObject) to System.Text.Json by re-parsing the serialized text.
    JsonDocument response = executeQueryResult != null ? JsonDocument.Parse(executeQueryResult.ToString()) : null;
    return new Tuple<JsonDocument, IMetadata>(response, null);
}
/// <summary>
/// Runs the query against the container. When both an id and a partition key value are
/// known, a point read is performed; otherwise the query iterator is used.
/// </summary>
/// <param name="structure">CosmosQueryStructure describing pagination and the continuation token</param>
/// <param name="querySpec">QueryDefinition defining a Cosmos SQL Query</param>
/// <param name="queryRequestOptions">The Cosmos query request options; MaxItemCount is set on it for paginated queries</param>
/// <param name="container">CosmosDB Container</param>
/// <param name="idValue">Id param</param>
/// <param name="partitionKeyValue">PartitionKey Value</param>
/// <returns>
/// For paginated queries, a connection-shaped JObject built from the first page (items,
/// base64 encoded continuation token, has-next-page flag). Otherwise the first item found,
/// or null when the query returns nothing.
/// </returns>
private static async Task<JObject> ExecuteQueryAsync(
    CosmosQueryStructure structure,
    QueryDefinition querySpec,
    QueryRequestOptions queryRequestOptions,
    Container container,
    string idValue,
    string partitionKeyValue)
{
    string requestContinuation = null;
    if (structure.IsPaginated)
    {
        queryRequestOptions.MaxItemCount = (int?)structure.MaxItemCount;
        // The continuation token travels base64 encoded; Cosmos expects the decoded form.
        requestContinuation = Base64Decode(structure.Continuation);
    }

    // If both partition key value and id value are provided, read that single item directly.
    if (!string.IsNullOrEmpty(partitionKeyValue) && !string.IsNullOrEmpty(idValue))
    {
        return await QueryByIdAndPartitionKey(container, idValue, partitionKeyValue, structure.IsPaginated);
    }

    // Otherwise run the query through an iterator, which is disposed when the method exits.
    using FeedIterator<JObject> query = container.GetItemQueryIterator<JObject>(querySpec, requestContinuation, queryRequestOptions);
    do
    {
        FeedResponse<JObject> page = await query.ReadNextAsync();

        // For connection type, return first page result directly
        if (structure.IsPaginated)
        {
            // foreach disposes the page enumerator, unlike a hand-driven MoveNext loop.
            JArray items = new();
            foreach (JObject item in page)
            {
                items.Add(item);
            }

            // Normalize an empty token to null so "no more pages" has a single representation.
            string responseContinuation = string.IsNullOrEmpty(page.ContinuationToken) ? null : page.ContinuationToken;

            return new JObject(
                new JProperty(QueryBuilder.PAGINATION_TOKEN_FIELD_NAME, Base64Encode(responseContinuation)),
                new JProperty(QueryBuilder.HAS_NEXT_PAGE_FIELD_NAME, responseContinuation != null),
                new JProperty(QueryBuilder.PAGINATION_FIELD_NAME, items));
        }

        // Non-paginated: keep reading pages until one contains an item.
        if (page.Count > 0)
        {
            return page.First();
        }
    }
    while (query.HasMoreResults);

    // Return null when query gets no result back
    return null;
}
/// <summary>
/// Executes the GraphQL query described by the given IMiddlewareContext and
/// returns every resulting item as a list of JSON documents.
/// </summary>
/// <param name="context">Middleware context of the GraphQL field being resolved.</param>
/// <param name="parameters">Arguments of the field, keyed by argument name.</param>
/// <param name="dataSourceName">Name of the data source whose Cosmos client and metadata provider are used.</param>
/// <returns>
/// A tuple whose first item holds one <see cref="JsonDocument"/> per item returned by the
/// query (all pages are drained) and whose second item is always null.
/// </returns>
public async Task<Tuple<IEnumerable<JsonDocument>, IMetadata>> ExecuteListAsync(IMiddlewareContext context, IDictionary<string, object> parameters, string dataSourceName)
{
    // TODO: fixme we have multiple rounds of serialization/deserialization JsonDocument/JObject
    // TODO: add support for nesting
    // TODO: add support for join query against another container
    // TODO: add support for TOP and Order-by push-down
    ISqlMetadataProvider metadataStoreProvider = _metadataProviderFactory.GetMetadataProvider(dataSourceName);
    CosmosQueryStructure structure = new(context, parameters, _runtimeConfigProvider, metadataStoreProvider, _authorizationResolver, _gQLFilterParser);

    CosmosClient client = _clientProvider.Clients[dataSourceName];
    Container container = client.GetDatabase(structure.Database).GetContainer(structure.Container);

    QueryDefinition querySpec = new(_queryBuilder.Build(structure));
    foreach (KeyValuePair<string, DbConnectionParam> parameterEntry in structure.Parameters)
    {
        querySpec = querySpec.WithParameter(parameterEntry.Key, parameterEntry.Value.Value);
    }

    List<JsonDocument> resultsAsList = new();

    // The iterator is disposable; release it once every page has been read.
    using FeedIterator<JObject> resultSetIterator = container.GetItemQueryIterator<JObject>(querySpec);
    while (resultSetIterator.HasMoreResults)
    {
        FeedResponse<JObject> nextPage = await resultSetIterator.ReadNextAsync();
        foreach (JObject item in nextPage)
        {
            // Bridge from Newtonsoft (JObject) to System.Text.Json by re-parsing the serialized text.
            resultsAsList.Add(JsonDocument.Parse(item.ToString()));
        }
    }

    return new Tuple<IEnumerable<JsonDocument>, IMetadata>(resultsAsList, null);
}
/// <inheritdoc />
/// <remarks>This engine provides no implementation; every call throws <see cref="NotImplementedException"/>.</remarks>
public Task<JsonDocument> ExecuteAsync(FindRequestContext context)
    => throw new NotImplementedException();
/// <inheritdoc />
/// <remarks>This engine provides no implementation; every call throws <see cref="NotImplementedException"/>.</remarks>
public Task<IActionResult> ExecuteAsync(StoredProcedureRequestContext context, string dataSourceName)
    => throw new NotImplementedException();
/// <inheritdoc />
/// <remarks>The element is handed back untouched; fieldSchema and metadata are not used.</remarks>
public JsonElement ResolveObject(JsonElement element, IObjectField fieldSchema, ref IMetadata metadata)
    => element;
/// <inheritdoc />
/// <remarks>
/// The metadata argument is required by the interface but is not used here.
/// Lists of object types are returned as a list of <see cref="JsonElement"/>; any other
/// list is deserialized into the runtime type declared by the field.
/// </remarks>
public object ResolveList(JsonElement array, IObjectField fieldSchema, ref IMetadata metadata)
{
    IType declaredType = fieldSchema.Type;

    // Reach the element type: a non-null list ([...]!) is wrapped twice, a nullable list ([...]) once.
    IType elementType = declaredType.IsNonNullType()
        ? declaredType.InnerType().InnerType()
        : declaredType.InnerType();

    // The element itself may carry a non-null wrapper as well.
    if (elementType.IsNonNullType())
    {
        elementType = elementType.InnerType();
    }

    return elementType.IsObjectType()
        ? JsonSerializer.Deserialize<List<JsonElement>>(array)
        : JsonSerializer.Deserialize(array, fieldSchema.RuntimeType);
}
/// <summary>
/// Point-reads one document from the container by id and partition key.
/// </summary>
/// <param name="container">Container to read from.</param>
/// <param name="idValue">Id of the document.</param>
/// <param name="partitionKeyValue">Partition key value of the document.</param>
/// <param name="IsPaginated">When true, the document is wrapped in a connection-shaped result.</param>
/// <returns>
/// The document, or, when paginated, an object holding the document as the only item with
/// a null pagination token and has-next-page set to false. Null when Cosmos reports NotFound.
/// </returns>
private static async Task<JObject> QueryByIdAndPartitionKey(Container container, string idValue, string partitionKeyValue, bool IsPaginated)
{
    JObject document;
    try
    {
        document = await container.ReadItemAsync<JObject>(idValue, new PartitionKey(partitionKeyValue));
    }
    catch (CosmosException ex) when (ex.StatusCode == System.Net.HttpStatusCode.NotFound)
    {
        // A missing document is reported as "no result", not as an error.
        return null;
    }

    if (!IsPaginated)
    {
        return document;
    }

    // Connection type: a single item and no further pages.
    return new JObject(
        new JProperty(QueryBuilder.PAGINATION_TOKEN_FIELD_NAME, null),
        new JProperty(QueryBuilder.HAS_NEXT_PAGE_FIELD_NAME, false),
        new JProperty(QueryBuilder.PAGINATION_FIELD_NAME, new JArray { document }));
}
/// <summary>
/// Returns the partition key path of the container. The metadata provider is asked first;
/// when it has no value, the path is read from the container properties and stored in the
/// metadata provider.
/// </summary>
/// <param name="container">Container whose partition key path is needed.</param>
/// <param name="metadataStoreProvider">Metadata provider that stores partition key paths per database and container.</param>
/// <returns>The partition key path of the container.</returns>
private static async Task<string> GetPartitionKeyPath(Container container, ISqlMetadataProvider metadataStoreProvider)
{
    string databaseId = container.Database.Id;
    string containerId = container.Id;

    string path = metadataStoreProvider.GetPartitionKeyPath(databaseId, containerId);
    if (path is null)
    {
        // Not known yet: fetch it from Cosmos and remember it.
        ContainerResponse containerResponse = await container.ReadContainerAsync();
        path = containerResponse.Resource.PartitionKeyPath;
        metadataStoreProvider.SetPartitionKeyPath(databaseId, containerId, path);
    }

    return path;
}
#nullable enable
/// <summary>
/// Works out the id and the partition key value of a request from its arguments.
/// </summary>
/// <param name="context">Provides the variables and field arguments needed to evaluate filter values</param>
/// <param name="parameters">Argument values keyed by argument name, such as id or filter</param>
/// <param name="container">Container whose partition key path is looked up</param>
/// <param name="structure">Supplies the partition key value when the filter does not provide one</param>
/// <param name="metadataStoreProvider">Stores and serves the partition key path of the container</param>
/// <returns>The id and partition key value; either can be null when it could not be resolved</returns>
private static async Task<(string? idValue, string? partitionKeyValue)> GetIdAndPartitionKey(
    IMiddlewareContext context,
    IDictionary<string, object?> parameters,
    Container container,
    CosmosQueryStructure structure,
    ISqlMetadataProvider metadataStoreProvider)
{
    string partitionKeyPath = await GetPartitionKeyPath(container, metadataStoreProvider);

    string? resolvedId = null;
    string? resolvedPartitionKey = null;

    // id and filter args can't exist at the same time
    foreach ((string argumentName, object? argumentValue) in parameters)
    {
        if (argumentName == QueryBuilder.ID_FIELD_NAME)
        {
            // The id was passed directly as an argument.
            resolvedId = argumentValue?.ToString();
            continue;
        }

        if (argumentName == QueryBuilder.FILTER_FIELD_NAME)
        {
            // Both values may be expressed inside the filter object.
            resolvedPartitionKey = GetPartitionKeyValue(context, partitionKeyPath, argumentValue);
            resolvedId = GetIdValue(context, argumentValue);
        }
    }

    // Nothing usable in the filter: fall back to the value held by the structure,
    // which is set when the _partitionKeyValue keyword exists in args.
    if (string.IsNullOrEmpty(resolvedPartitionKey))
    {
        resolvedPartitionKey = structure.PartitionKeyValue;
    }

    return (resolvedId, resolvedPartitionKey);
}
/// <summary>
/// Walks the filter object along `PartitionKeyPath` and returns the value compared with
/// `eq` at the end of that path, using recursion: each call consumes the first segment of
/// the path and descends into the filter field of the same name.
/// Example of `PartitionKeyPath` is `/character/id`.
/// </summary>
/// <param name="context">Supplies the filter argument definition and the variables used to extract the value</param>
/// <param name="partitionKeyPath">Remaining partition key path; empty once every segment has been consumed</param>
/// <param name="parameter">Filter fields at the current nesting level; must be castable to a list of ObjectFieldNode</param>
/// <returns>The partition key value as a string, or null when the filter does not contain it</returns>
#nullable enable
private static string? GetPartitionKeyValue(IMiddlewareContext context, string? partitionKeyPath, object? parameter)
{
    if (parameter is null || partitionKeyPath is null)
    {
        return null;
    }

    // A path such as "/character/id" splits into ["", "character", "id"]; index 1 is the next segment.
    string[] pathSegments = partitionKeyPath.Split("/");
    string currentEntity = pathSegments.Length > 1 ? pathSegments[1] : string.Empty;

    foreach (ObjectFieldNode item in (IList<ObjectFieldNode>)parameter)
    {
        // Path fully consumed: the value sits under the "eq" operator of this level.
        if (partitionKeyPath == string.Empty
            && string.Equals(item.Name.Value, "eq", StringComparison.OrdinalIgnoreCase))
        {
            return ExecutionHelper.ExtractValueFromIValueNode(
                item.Value,
                context.Selection.Field.Arguments[QueryBuilder.FILTER_FIELD_NAME],
                context.Variables)?.ToString();
        }

        if (partitionKeyPath != string.Empty
            && string.Equals(item.Name.Value, currentEntity, StringComparison.OrdinalIgnoreCase))
        {
            // Drop the matched segment and recurse into the nested filter object.
            // Ordinal search: path segments are identifiers, not culture-sensitive text.
            int index = partitionKeyPath.IndexOf(currentEntity, StringComparison.Ordinal);
            string newPartitionKeyPath = partitionKeyPath[(index + currentEntity.Length)..];

            return GetPartitionKeyValue(context, newPartitionKeyPath, item.Value.Value);
        }
    }

    return null;
}
/// <summary>
/// Extracts the id from a filter object, i.e. the value compared with `eq` under the
/// first filter field named `id` (case-insensitive).
/// </summary>
/// <param name="context">Supplies the filter argument definition and the variables used to extract the value</param>
/// <param name="parameter">Top-level filter fields; must be castable to a list of ObjectFieldNode</param>
/// <returns>The id as a string, or null when the filter has no `id` field with an `eq` operator</returns>
private static string? GetIdValue(IMiddlewareContext context, object? parameter)
{
    if (parameter is null)
    {
        return null;
    }

    IList<ObjectFieldNode> filterFields = (IList<ObjectFieldNode>)parameter;
    ObjectFieldNode? idField = filterFields.FirstOrDefault(
        field => string.Equals(field.Name.Value, "id", StringComparison.OrdinalIgnoreCase));
    if (idField is null)
    {
        return null;
    }

    // Only an equality comparison pins the request to one id.
    IList<ObjectFieldNode>? idOperators = (IList<ObjectFieldNode>?)idField.Value.Value;
    ObjectFieldNode? equalsOperator = idOperators?.FirstOrDefault(op => op.Name.Value == "eq");
    if (equalsOperator is null)
    {
        return null;
    }

    object? extracted = ExecutionHelper.ExtractValueFromIValueNode(
        equalsOperator.Value,
        context.Selection.Field.Arguments[QueryBuilder.FILTER_FIELD_NAME],
        context.Variables);

    return extracted?.ToString();
}
/// <summary>
/// Encodes the UTF-8 bytes of the given text as base64.
/// </summary>
/// <param name="plainText">Text to encode; may be null.</param>
/// <returns>The base64 representation, or null when <paramref name="plainText"/> is null.</returns>
private static string? Base64Encode(string? plainText)
{
    // Null is passed in when there is no continuation token; pass it through unchanged.
    if (plainText is null)
    {
        return null;
    }

    return Convert.ToBase64String(Encoding.UTF8.GetBytes(plainText));
}
/// <summary>
/// Decodes base64 data and interprets the resulting bytes as UTF-8 text.
/// </summary>
/// <param name="base64EncodedData">Base64 text to decode; may be null.</param>
/// <returns>The decoded text, or null when <paramref name="base64EncodedData"/> is null.</returns>
/// <exception cref="FormatException">The input is not valid base64.</exception>
private static string? Base64Decode(string? base64EncodedData)
{
    // Null means no continuation token was supplied; pass it through unchanged.
    if (base64EncodedData is null)
    {
        return null;
    }

    return Encoding.UTF8.GetString(Convert.FromBase64String(base64EncodedData));
}
}
}