AzureCosmosDB/csharp/DocumentVectorPipelineFunctions/CosmosDBClientWrapper.cs (151 lines of code) (raw):
using System.Globalization;
using System.Net;
using System.Text.Json.Serialization;
using Microsoft.Azure.Cosmos;
using Microsoft.Extensions.Logging;
using OpenAI.Embeddings;
using Container = Microsoft.Azure.Cosmos.Container;
namespace DocumentVectorPipelineFunctions;
/// <summary>
/// Wraps a <see cref="CosmosClient"/> and the container that stores document chunks
/// together with their embedding vectors. A single instance is cached for the whole
/// process; obtain it through <see cref="CreateInstance"/>.
/// </summary>
internal class CosmosDBClientWrapper
{
    /// <summary>Maximum number of attempts for a single upsert that keeps being throttled.</summary>
    private const int MaxRetryCount = 100;

    /// <summary>
    /// Delay applied between throttled attempts when the service response carries no
    /// retry-after hint, so that the retry loop never spins without pausing.
    /// </summary>
    private static readonly TimeSpan FallbackRetryDelay = TimeSpan.FromSeconds(1);

    private static CosmosDBClientWrapper? instance;

    private readonly CosmosClient client;
    private readonly ILogger logger;
    private Container? container;

    private CosmosDBClientWrapper(CosmosClient client, ILogger logger)
    {
        this.client = client;
        this.logger = logger;
    }

    /// <summary>
    /// Returns the cached wrapper, creating it (and the backing database and container)
    /// on first use.
    /// </summary>
    /// <param name="client">Cosmos client used when the wrapper has to be created.</param>
    /// <param name="logger">Logger used when the wrapper has to be created.</param>
    /// <returns>The shared <see cref="CosmosDBClientWrapper"/> instance.</returns>
    /// <remarks>
    /// Once an instance has been cached, later calls return it unchanged and ignore their
    /// arguments. The cache is published only after initialization has succeeded, so a
    /// failed initialization is retried by the next call. No synchronization is performed.
    /// </remarks>
    public static async ValueTask<CosmosDBClientWrapper> CreateInstance(CosmosClient client, ILogger logger)
    {
        if (instance is not null)
        {
            return instance;
        }

        var created = new CosmosDBClientWrapper(client, logger);
        await created.GetOrCreateDatabaseAndContainerAsync();
        instance = created;
        return instance;
    }

    /// <summary>
    /// Upserts one item per chunk, pairing <c>chunks[i]</c> with <c>embeddings[i]</c>.
    /// All upserts are started first and then awaited together.
    /// </summary>
    /// <param name="fileUri">URI of the source document; stored as the <c>document_url</c> property.</param>
    /// <param name="chunks">Text chunks of the document.</param>
    /// <param name="embeddings">Embeddings aligned by index with <paramref name="chunks"/>.</param>
    /// <param name="cancellationToken">Token used to cancel the upserts and the retry delays.</param>
    /// <exception cref="InvalidOperationException">The container has not been initialized.</exception>
    /// <exception cref="ArgumentNullException"><paramref name="chunks"/> or <paramref name="embeddings"/> is null.</exception>
    /// <exception cref="ArgumentOutOfRangeException">There are fewer embeddings than chunks.</exception>
    public async Task UpsertDocumentsAsync(string fileUri, List<TextChunk> chunks, EmbeddingCollection embeddings, CancellationToken cancellationToken)
    {
        if (this.container == null)
        {
            throw new InvalidOperationException("Container is not initialized.");
        }

        ArgumentNullException.ThrowIfNull(chunks);
        ArgumentNullException.ThrowIfNull(embeddings);

        // Fail before any write is started instead of part-way through the loop,
        // which would leave already-started upserts running unobserved.
        if (embeddings.Count < chunks.Count)
        {
            throw new ArgumentOutOfRangeException(
                nameof(embeddings),
                embeddings.Count,
                $"Expected at least {chunks.Count} embeddings, one per chunk.");
        }

        var upsertTasks = new List<Task<ItemResponse<DocumentChunk>>>(chunks.Count);
        for (var index = 0; index < chunks.Count; index++)
        {
            var chunk = chunks[index];
            var documentChunk = new DocumentChunk
            {
                ChunkId = chunk.ChunkNumber.ToString("d", CultureInfo.InvariantCulture),
                DocumentUrl = fileUri,
                Embedding = embeddings[index].Vector,
                ChunkText = chunk.Text,
            };

            upsertTasks.Add(this.UpsertDocumentWithRetryAsync(documentChunk, CosmosDBClientWrapper.MaxRetryCount, cancellationToken));
        }

        try
        {
            await Task.WhenAll(upsertTasks);
        }
        catch (Exception)
        {
            // Awaiting Task.WhenAll rethrows only the first failure (never an
            // AggregateException), so walk every faulted task to log the headers
            // of each Cosmos failure.
            foreach (var upsertTask in upsertTasks)
            {
                var failure = upsertTask.Exception;
                if (failure is null)
                {
                    continue;
                }

                foreach (var inner in failure.Flatten().InnerExceptions)
                {
                    if (inner is CosmosException cosmosException)
                    {
                        this.LogHeaders(cosmosException.Headers, "Failed to upsert a document chunk.");
                    }
                }
            }

            throw;
        }
    }

    /// <summary>
    /// Upserts a single item, retrying while the service answers with
    /// <see cref="HttpStatusCode.TooManyRequests"/>.
    /// </summary>
    /// <param name="document">Item to upsert.</param>
    /// <param name="maxRetryAttempts">Maximum number of attempts before giving up.</param>
    /// <param name="cancellationToken">Token used to cancel the request and the waits between attempts.</param>
    /// <returns>The response of the successful upsert.</returns>
    /// <exception cref="InvalidOperationException">The container has not been initialized.</exception>
    /// <exception cref="Exception">
    /// Every attempt was throttled; the last throttling error is attached as the inner exception.
    /// Any non-throttling failure is logged and rethrown unchanged.
    /// </exception>
    private async Task<ItemResponse<DocumentChunk>> UpsertDocumentWithRetryAsync(
        DocumentChunk document,
        int maxRetryAttempts,
        CancellationToken cancellationToken)
    {
        if (this.container == null)
        {
            throw new InvalidOperationException("Container is not initialized.");
        }

        CosmosException? lastThrottle = null;
        for (var attempt = 0; attempt < maxRetryAttempts; attempt++)
        {
            try
            {
                return await this.container.UpsertItemAsync(document, cancellationToken: cancellationToken);
            }
            catch (CosmosException ex) when (ex.StatusCode == HttpStatusCode.TooManyRequests)
            {
                lastThrottle = ex;

                // Honor the delay suggested by the service; fall back to a fixed
                // pause when none is supplied so retries are never back-to-back.
                await Task.Delay(ex.RetryAfter ?? FallbackRetryDelay, cancellationToken);
            }
            catch (Exception ex)
            {
                // Pass the exception object so the log entry keeps the stack trace.
                this.logger.LogError(ex, "An error occurred while upserting document with ID {chunkId}: {exceptionMessage}", document.ChunkId, ex.Message);
                throw;
            }
        }

        throw new Exception($"Max retry attempts reached for document with ID {document.ChunkId}. Operation failed.", lastThrottle);
    }

    /// <summary>
    /// Creates the <c>semantic_search_db</c> database and the <c>doc_search_container</c>
    /// container if they do not exist yet, and stores the container for later upserts.
    /// The container is partitioned by <c>/document_url</c> and declares a vector
    /// embedding and a vector index on <c>/embedding</c>.
    /// </summary>
    private async Task GetOrCreateDatabaseAndContainerAsync()
    {
        var dbResponse = await this.client.CreateDatabaseIfNotExistsAsync("semantic_search_db");

        var indexingPolicy = new IndexingPolicy()
        {
            // TODO: Include Full-Text Index for the chunk_text property.
            VectorIndexes =
            [
                new VectorIndexPath
                {
                    Path = "/embedding",
                    Type = VectorIndexType.QuantizedFlat,
                }
            ]
        };

        // Keep the raw vector values out of the regular index; the path is
        // covered by the vector index declared above.
        indexingPolicy.ExcludedPaths.Add(new ExcludedPath { Path = "/embedding/*" });

        var containerProperties = new ContainerProperties
        {
            Id = "doc_search_container",
            PartitionKeyPath = "/document_url",
            IndexingPolicy = indexingPolicy,
            VectorEmbeddingPolicy = new(
            [
                new Microsoft.Azure.Cosmos.Embedding
                {
                    DataType = VectorDataType.Float32,
                    Dimensions = 1536,
                    DistanceFunction = DistanceFunction.Cosine,
                    Path = "/embedding"
                },
            ]),
        };

        var containerResponse = await dbResponse.Database.CreateContainerIfNotExistsAsync(containerProperties);
        this.container = containerResponse.Container;

        // Any status other than 200 OK gets its response headers logged.
        // NOTE(review): presumably this is the "container was just created" case - confirm.
        if (containerResponse.StatusCode != System.Net.HttpStatusCode.OK)
        {
            this.LogHeaders(containerResponse.Headers, "Created a container.");
        }
    }

    /// <summary>
    /// Logs every header of a Cosmos response as a warning inside a logging scope.
    /// </summary>
    /// <param name="headers">Headers to log.</param>
    /// <param name="scopeDescription">Text describing the operation the headers belong to.</param>
    private void LogHeaders(Headers headers, string scopeDescription)
    {
        using var scope = this.logger.BeginScope(scopeDescription);
        foreach (var headerName in headers.AllKeys())
        {
            this.logger.LogWarning("Header: {header}, Value: '{value}'", headerName, headers[headerName]);
        }
    }

    /// <summary>
    /// Shape of the item stored for each chunk.
    /// </summary>
    /// <remarks>
    /// NOTE(review): the property names come from System.Text.Json attributes, so this
    /// presumably relies on the <see cref="CosmosClient"/> being configured with a
    /// System.Text.Json based serializer - confirm where the client is constructed.
    /// </remarks>
    private class DocumentChunk
    {
        /// <summary>Chunk number formatted as an invariant decimal string; used as the item id.</summary>
        [JsonPropertyName("id")]
        public string? ChunkId { get; init; }

        /// <summary>URI of the source document; matches the container's partition key path.</summary>
        [JsonPropertyName("document_url")]
        public string? DocumentUrl { get; init; }

        /// <summary>Text of the chunk.</summary>
        [JsonPropertyName("chunk_text")]
        public string? ChunkText { get; init; }

        /// <summary>Embedding vector of the chunk.</summary>
        [JsonPropertyName("embedding")]
        public ReadOnlyMemory<float> Embedding { get; init; }
    }
}