Elastic.SemanticKernel.Connectors.Elasticsearch/ElasticsearchVectorStoreRecordCollection.cs (263 lines of code) (raw):

// Licensed to Elasticsearch B.V under one or more agreements.
// Elasticsearch B.V licenses this file to you under the Apache 2.0 License.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Text.Json.Nodes;
using System.Threading;
using System.Threading.Tasks;
using Elastic.Clients.Elasticsearch;
using Elastic.Clients.Elasticsearch.QueryDsl;
using Elastic.Transport;
using Microsoft.Extensions.VectorData;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.Data;

namespace Elastic.SemanticKernel.Connectors.Elasticsearch;

#pragma warning disable CA1711 // Identifiers should not have incorrect suffix
/// <summary>
/// Service for storing and retrieving vector records, that uses Elasticsearch as the underlying storage.
/// </summary>
/// <typeparam name="TRecord">The data model to use for adding, updating and retrieving data from storage.</typeparam>
public sealed class ElasticsearchVectorStoreRecordCollection<TRecord> : IVectorStoreRecordCollection<string, TRecord>
#pragma warning restore CA1711 // Identifiers should not have incorrect suffix
{
    /// <summary>The name of this database for telemetry purposes.</summary>
    private const string DatabaseName = "Elasticsearch";

    /// <summary>A set of types that a key on the provided model may have.</summary>
    private static readonly HashSet<Type> SupportedKeyTypes =
    [
        typeof(string)
    ];

    /// <summary>The default options for vector search.</summary>
    private static readonly VectorSearchOptions DefaultVectorSearchOptions = new();

    /// <summary>Elasticsearch client used to manage the index and documents backing this collection.</summary>
    private readonly MockableElasticsearchClient _elasticsearchClient;

    /// <summary>Optional configuration options for this class.</summary>
    private readonly ElasticsearchVectorStoreRecordCollectionOptions<TRecord> _options;

    /// <summary>A helper to access property information for the current data model and record definition.</summary>
    private readonly VectorStoreRecordPropertyReader _propertyReader;

    /// <summary>A mapping from each <see cref="VectorStoreRecordProperty" /> of the record definition to its storage (field) name.</summary>
    private readonly Dictionary<VectorStoreRecordProperty, string> _propertyToStorageName;

    /// <summary>
    /// Maps between <typeparamref name="TRecord" /> and the storage model used by this connector: a tuple of the
    /// (optional) document id and the JSON document body.
    /// </summary>
    private readonly IVectorStoreRecordMapper<TRecord, (string? id, JsonObject document)> _mapper;

    /// <summary>
    /// Initializes a new instance of the <see cref="ElasticsearchVectorStoreRecordCollection{TRecord}" /> class.
    /// </summary>
    /// <param name="elasticsearchClient">
    /// Elasticsearch client used to manage the index and documents backing this collection.
    /// </param>
    /// <param name="collectionName">
    /// The name of the collection (Elasticsearch index) that this
    /// <see cref="ElasticsearchVectorStoreRecordCollection{TRecord}" /> will access.
    /// </param>
    /// <param name="options">Optional configuration options for this class.</param>
    /// <exception cref="ArgumentNullException">Thrown if the <paramref name="elasticsearchClient" /> is null.</exception>
    /// <exception cref="ArgumentException">Thrown for any misconfigured options.</exception>
    public ElasticsearchVectorStoreRecordCollection(ElasticsearchClient elasticsearchClient, string collectionName,
        ElasticsearchVectorStoreRecordCollectionOptions<TRecord>? options = null) :
        this(new MockableElasticsearchClient(elasticsearchClient), collectionName, options)
    {
    }

    /// <summary>
    /// Initializes a new instance of the <see cref="ElasticsearchVectorStoreRecordCollection{TRecord}" /> class.
    /// </summary>
    /// <param name="elasticsearchClient">
    /// Elasticsearch client used to manage the index and documents backing this collection.
    /// </param>
    /// <param name="collectionName">
    /// The name of the collection (Elasticsearch index) that this
    /// <see cref="ElasticsearchVectorStoreRecordCollection{TRecord}" /> will access.
    /// </param>
    /// <param name="options">Optional configuration options for this class.</param>
    /// <exception cref="ArgumentNullException">Thrown if the <paramref name="elasticsearchClient" /> is null.</exception>
    /// <exception cref="ArgumentException">Thrown for any misconfigured options.</exception>
    internal ElasticsearchVectorStoreRecordCollection(MockableElasticsearchClient elasticsearchClient, string collectionName,
        ElasticsearchVectorStoreRecordCollectionOptions<TRecord>? options = null)
    {
        // Verify.
        Verify.NotNull(elasticsearchClient);
        Verify.NotNullOrWhiteSpace(collectionName);
        VectorStoreRecordPropertyVerification.VerifyGenericDataModelKeyType(typeof(TRecord), false /* TODO: options?.CustomMapper is not null */, SupportedKeyTypes);
        VectorStoreRecordPropertyVerification.VerifyGenericDataModelDefinitionSupplied(typeof(TRecord), options?.VectorStoreRecordDefinition is not null);

        // Assign.
        _elasticsearchClient = elasticsearchClient;
        CollectionName = collectionName;
        _options = options ?? new ElasticsearchVectorStoreRecordCollectionOptions<TRecord>();
        _propertyReader = new VectorStoreRecordPropertyReader(
            typeof(TRecord),
            _options.VectorStoreRecordDefinition,
            new VectorStoreRecordPropertyReaderOptions
            {
                RequiresAtLeastOneVector = false,
                SupportsMultipleKeys = false,
                SupportsMultipleVectors = true
            });

        if (typeof(TRecord) == typeof(VectorStoreGenericDataModel<string>))
        {
            // Prioritize the user provided `StoragePropertyName` or fall-back to using the `DefaultFieldNameInferrer`
            // function of the Elasticsearch client which by default redirects to the
            // `JsonSerializerOptions.PropertyNamingPolicy.Convert() method.
            _propertyToStorageName = _propertyReader.Properties.ToDictionary(k => k,
                v => v.StoragePropertyName ??
                     _elasticsearchClient.ElasticsearchClient.ElasticsearchClientSettings.DefaultFieldNameInferrer(v.DataModelPropertyName));

            _mapper = (new ElasticsearchGenericDataModelMapper(_propertyToStorageName, elasticsearchClient.ElasticsearchClient.ElasticsearchClientSettings) as
                IVectorStoreRecordMapper<TRecord, (string id, JsonObject document)>)!;
        }
        else
        {
            // Use the built-in property name inference of the Elasticsearch client. The default implementation
            // prioritizes `JsonPropertyName` attributes and falls-back to the `DefaultFieldNameInferrer` function,
            // which by default redirects to the `JsonSerializerOptions.PropertyNamingPolicy.Convert() method.
            _propertyToStorageName = _propertyReader.Properties.ToDictionary(k => k, v =>
            {
                var info = _propertyReader.KeyPropertiesInfo.FirstOrDefault(x => string.Equals(x.Name, v.DataModelPropertyName, StringComparison.Ordinal)) ??
                           _propertyReader.VectorPropertiesInfo.FirstOrDefault(x => string.Equals(x.Name, v.DataModelPropertyName, StringComparison.Ordinal)) ??
                           _propertyReader.DataPropertiesInfo.FirstOrDefault(x => string.Equals(x.Name, v.DataModelPropertyName, StringComparison.Ordinal));

                if (info is null)
                {
                    // Every property in the definition originates from one of the three lists above.
                    throw new InvalidOperationException("unreachable");
                }

                return _elasticsearchClient.ElasticsearchClient.Infer.PropertyName(info);
            });

            _mapper = new ElasticsearchDataModelMapper<TRecord>(_propertyToStorageName, elasticsearchClient.ElasticsearchClient.ElasticsearchClientSettings);
        }

        // Validate property types.
        _propertyReader.VerifyKeyProperties(SupportedKeyTypes);
        VectorStoreRecordPropertyVerification.VerifyPropertyTypes(_propertyReader.VectorProperties,
            [typeof(ReadOnlyMemory<float>), typeof(ReadOnlyMemory<float>?)], [typeof(float)], "Vector");
    }

    /// <inheritdoc />
    public string CollectionName { get; }

    /// <inheritdoc />
    public Task<bool> CollectionExistsAsync(CancellationToken cancellationToken = default)
    {
        return RunOperationAsync(
            "indices.exists",
            () => _elasticsearchClient.IndexExistsAsync(CollectionName, cancellationToken));
    }

    /// <inheritdoc />
    public Task CreateCollectionAsync(CancellationToken cancellationToken = default)
    {
        var propertyMappings = ElasticsearchVectorStoreCollectionCreateMapping.BuildPropertyMappings(_propertyReader, _propertyToStorageName);

        return RunOperationAsync(
            "indices.create",
            () => _elasticsearchClient.CreateIndexAsync(CollectionName, propertyMappings, cancellationToken));
    }

    /// <inheritdoc />
    /// <remarks>
    /// The existence check and the creation are two separate requests, so a concurrent creation of the same index
    /// between them is not guarded against here.
    /// </remarks>
    public async Task CreateCollectionIfNotExistsAsync(CancellationToken cancellationToken = default)
    {
        if (!await CollectionExistsAsync(cancellationToken).ConfigureAwait(false))
        {
            await CreateCollectionAsync(cancellationToken).ConfigureAwait(false);
        }
    }

    /// <inheritdoc />
    public Task DeleteCollectionAsync(CancellationToken cancellationToken = default)
    {
        return RunOperationAsync(
            "indices.delete",
            () => _elasticsearchClient.DeleteIndexAsync(CollectionName, cancellationToken));
    }

    /// <inheritdoc />
    public async Task<TRecord?> GetAsync(string key, GetRecordOptions? options = null, CancellationToken cancellationToken = default)
    {
        // TODO: Handle options

        // Validate the key up-front, consistent with DeleteAsync.
        Verify.NotNullOrWhiteSpace(key);

        var storageModel = await RunOperationAsync(
                "get",
                () => _elasticsearchClient.GetDocumentAsync(CollectionName, key, cancellationToken))
            .ConfigureAwait(false);

        // No value from the client means there is no document to map (presumably "not found").
        if (!storageModel.HasValue)
        {
            return default;
        }

        var record = VectorStoreErrorHandler.RunModelConversion(DatabaseName, CollectionName, "get",
            () => _mapper.MapFromStorageToDataModel(storageModel.Value, new StorageToDataModelMapperOptions()));

        return record;
    }

    /// <inheritdoc />
    public async IAsyncEnumerable<TRecord> GetBatchAsync(IEnumerable<string> keys, GetRecordOptions? options = null,
        [EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        // TODO: Use mget endpoint
        // TODO: Handle options

        Verify.NotNull(keys);

        foreach (var key in keys)
        {
            var record = await GetAsync(key, options, cancellationToken).ConfigureAwait(false);

            // Missing documents are skipped rather than yielded as default values.
            if (record is null)
            {
                continue;
            }

            yield return record;
        }
    }

    /// <inheritdoc />
    public async Task DeleteAsync(string key, DeleteRecordOptions? options = null, CancellationToken cancellationToken = default)
    {
        // TODO: Handle options

        Verify.NotNullOrWhiteSpace(key);

        await RunOperationAsync(
                "delete",
                () => _elasticsearchClient.DeleteDocumentAsync(CollectionName, key, cancellationToken))
            .ConfigureAwait(false);
    }

    /// <inheritdoc />
    public async Task DeleteBatchAsync(IEnumerable<string> keys, DeleteRecordOptions? options = null, CancellationToken cancellationToken = default)
    {
        // TODO: Use _bulk endpoint
        // TODO: Handle options

        Verify.NotNull(keys);

        foreach (var key in keys)
        {
            await DeleteAsync(key, options, cancellationToken).ConfigureAwait(false);
        }
    }

    /// <inheritdoc />
    public async Task<string> UpsertAsync(TRecord record, UpsertRecordOptions? options = null, CancellationToken cancellationToken = default)
    {
        // TODO: Handle options

        Verify.NotNull(record);

        var storageModel = VectorStoreErrorHandler.RunModelConversion(DatabaseName, CollectionName, "index",
            () => _mapper.MapFromDataToStorageModel(record));

        var id = await RunOperationAsync(
                "index",
                () => _elasticsearchClient.IndexDocumentAsync(CollectionName, storageModel.id!, storageModel.document, cancellationToken))
            .ConfigureAwait(false);

        return id;
    }

    /// <inheritdoc />
    public async IAsyncEnumerable<string> UpsertBatchAsync(IEnumerable<TRecord> records, UpsertRecordOptions? options = null,
        [EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        // TODO: Use _bulk endpoint
        // TODO: Handle options

        Verify.NotNull(records);

        foreach (var record in records)
        {
            yield return await UpsertAsync(record, options, cancellationToken).ConfigureAwait(false);
        }
    }

    /// <inheritdoc />
    /// <remarks>
    /// Supported vector types are any <see cref="IEnumerable{T}" /> of <see cref="float" /> and
    /// <see cref="ReadOnlyMemory{T}" /> of <see cref="float" />. Results are mapped to <typeparamref name="TRecord" />
    /// lazily, as the returned result stream is enumerated.
    /// </remarks>
    public async Task<VectorSearchResults<TRecord>> VectorizedSearchAsync<TVector>(TVector vector, VectorSearchOptions? options = null,
        CancellationToken cancellationToken = default)
    {
        Verify.NotNull(vector);

        // Validate inputs.

        if (_propertyReader.FirstVectorPropertyName is null)
        {
            throw new InvalidOperationException("The collection does not have any vector fields, so vector search is not possible.");
        }

        // Convert the input to a float array exactly once (ICollection<float> is covered by IEnumerable<float>).
        float[] queryVector = vector switch
        {
            IEnumerable<float> v => v.ToArray(),
            ReadOnlyMemory<float> v => v.ToArray(),
            _ => throw new NotSupportedException($"The provided vector type {vector.GetType().FullName} is not supported by the Elasticsearch connector.")
        };

        var searchOptions = options ?? DefaultVectorSearchOptions;

        // Specify the vector name: the first vector property unless the caller names one explicitly.

        var vectorProperty = _propertyReader.VectorProperties[0];
        if (!string.IsNullOrWhiteSpace(searchOptions.VectorPropertyName))
        {
            // Fail with a descriptive message instead of the generic "Sequence contains no matching element".
            vectorProperty = _propertyReader.VectorProperties.FirstOrDefault(x =>
                                 string.Equals(x.DataModelPropertyName, searchOptions.VectorPropertyName, StringComparison.Ordinal))
                             ?? throw new InvalidOperationException(
                                 $"The collection does not have a vector field named '{searchOptions.VectorPropertyName}'.");
        }

        // Build search query.

        var knnQuery = new KnnQuery
        {
            Field = _propertyToStorageName[vectorProperty]!,
            QueryVector = queryVector
        };

        var filterQueries = ElasticsearchVectorStoreCollectionSearchMapping.BuildFilter(searchOptions.Filter, _propertyToStorageName);
        if (filterQueries.Count != 0)
        {
            knnQuery.Filter = filterQueries;
        }

        // Execute search query.

        var result = await RunOperationAsync(
                "search",
                () => _elasticsearchClient.SearchAsync(
                    CollectionName,
                    Query.Knn(knnQuery),
                    searchOptions.Skip,
                    searchOptions.Top,
                    cancellationToken: cancellationToken))
            .ConfigureAwait(false);

        // Map results.

        var mappedResults = result.hits.Select(x => new VectorSearchResult<TRecord>(
                VectorStoreErrorHandler.RunModelConversion(DatabaseName, CollectionName, "search",
                    () => _mapper.MapFromStorageToDataModel((x.id, x.document), new StorageToDataModelMapperOptions())),
                x.score
            )
        );

        return new VectorSearchResults<TRecord>(mappedResults.ToAsyncEnumerable())
        {
            // Only report a total when requested and the reported value is non-negative.
            TotalCount = searchOptions.IncludeTotalCount ? (result.total is var total and >= 0) ? total : null : null
        };
    }

    /// <summary>
    /// Run the given operation and wrap any <see cref="TransportException" /> with
    /// <see cref="VectorStoreOperationException" />.
    /// </summary>
    /// <param name="operationName">The type of database operation being run.</param>
    /// <param name="operation">The operation to run.</param>
    /// <returns>A task that completes when the operation has finished.</returns>
    /// <exception cref="VectorStoreOperationException">Thrown when the operation raises a <see cref="TransportException" />.</exception>
    private async Task RunOperationAsync(string operationName, Func<Task> operation)
    {
        try
        {
            await operation.Invoke().ConfigureAwait(false);
        }
        catch (TransportException ex)
        {
            throw new VectorStoreOperationException("Call to vector store failed.", ex)
            {
                VectorStoreType = DatabaseName,
                CollectionName = CollectionName,
                OperationName = operationName
            };
        }
    }

    /// <summary>
    /// Run the given operation and wrap any <see cref="TransportException" /> with
    /// <see cref="VectorStoreOperationException" />.
    /// </summary>
    /// <typeparam name="T">The response type of the operation.</typeparam>
    /// <param name="operationName">The type of database operation being run.</param>
    /// <param name="operation">The operation to run.</param>
    /// <returns>The result of the operation.</returns>
    /// <exception cref="VectorStoreOperationException">Thrown when the operation raises a <see cref="TransportException" />.</exception>
    private async Task<T> RunOperationAsync<T>(string operationName, Func<Task<T>> operation)
    {
        try
        {
            return await operation.Invoke().ConfigureAwait(false);
        }
        catch (TransportException ex)
        {
            throw new VectorStoreOperationException("Call to vector store failed.", ex)
            {
                VectorStoreType = DatabaseName,
                CollectionName = CollectionName,
                OperationName = operationName
            };
        }
    }
}