AzureSQL/csharp/DocumentVectorPipelineFunctions/BlobTriggerFunction.cs

using System.ClientModel;
using System.Data;
using System.Globalization;
using System.Net;
using System.Text.Json;
using Azure;
using Azure.AI.FormRecognizer.DocumentAnalysis;
using Azure.Core;
using Azure.Identity;
using Azure.Storage.Blobs;
using Dapper;
using Microsoft.Azure.Functions.Worker;
using Microsoft.Data.SqlClient;
using Microsoft.Extensions.Configuration;
using Microsoft.Extensions.Logging;
using OpenAI.Embeddings;

namespace DocumentVectorPipelineFunctions;

public class BlobTriggerFunction(
    IConfiguration configuration,
    DocumentAnalysisClient documentAnalysisClient,
    ILoggerFactory loggerFactory,
    EmbeddingClient embeddingClient)
{
    private const string AzureOpenAIModelDeploymentDimensionsName = "AzureOpenAIModelDimensions";
    private const string SqlConnectionString = "SqlConnectionString";

    private readonly ILogger _logger = loggerFactory.CreateLogger<BlobTriggerFunction>();
    private readonly string TableSchemaName = Environment.GetEnvironmentVariable("DocumentTableSchema") ?? "dbo";
    private readonly string TableName = Environment.GetEnvironmentVariable("DocumentTableName") ?? "documents";

    // Reserved for token-based SQL authentication (see the commented-out AccessToken lines below).
    private readonly string? managedIdentityClientId = Environment.GetEnvironmentVariable("AzureManagedIdentityClientId");

    private static readonly int DefaultDimensions = 1536;

    private const int MaxRetryCount = 100;
    private const int RetryDelay = 10 * 1000; // 10 seconds
    private const int MaxBatchSize = 10;
    private const int MaxDegreeOfParallelism = 50;

    private int embeddingDimensions = DefaultDimensions;

    [Function("BlobTriggerFunction")]
    public async Task Run([BlobTrigger("documents/{name}", Connection = "AzureBlobStorageAccConnectionString")] BlobClient blobClient)
    {
        this._logger.LogInformation("Starting processing of blob name: '{name}'", blobClient.Name);

        if (await blobClient.ExistsAsync())
        {
            await this.HandleBlobCreateEventAsync(blobClient);
        }
        else
        {
            await this.HandleBlobDeleteEventAsync(blobClient);
        }

        this._logger.LogInformation("Finished processing of blob name: '{name}'", blobClient.Name);
    }

    private async Task HandleBlobCreateEventAsync(BlobClient blobClient)
    {
        embeddingDimensions = configuration.GetValue<int>(AzureOpenAIModelDeploymentDimensionsName, DefaultDimensions);
        var connectionString = configuration.GetValue<string>(SqlConnectionString);

        _logger.LogInformation("Using OpenAI model dimensions: '{embeddingDimensions}'.", embeddingDimensions);
        _logger.LogInformation("Analyzing document using DocumentAnalyzerService from blob: '{blobName}' using layout: {layout}", blobClient.Name, "prebuilt-read");

        // Download the blob and run Document Intelligence's prebuilt-read model over it.
        using var memoryStream = new MemoryStream();
        await blobClient.DownloadToAsync(memoryStream);
        memoryStream.Seek(0, SeekOrigin.Begin);

        var operation = await documentAnalysisClient.AnalyzeDocumentAsync(
            WaitUntil.Completed,
            "prebuilt-read",
            memoryStream);

        var result = operation.Value;
        _logger.LogInformation("Extracted content from '{name}', # pages {pageCount}", blobClient.Name, result.Pages.Count);

        // Split the extracted text into fixed-size chunks, then group the chunks
        // into batches of at most MaxBatchSize for embedding generation.
        var textChunks = TextChunker.FixedSizeChunking(result);

        var listOfBatches = new List<List<TextChunk>>();
        var batchChunkTexts = new List<TextChunk>(MaxBatchSize);
        int totalChunksCount = 0;

        foreach (var chunk in textChunks)
        {
            batchChunkTexts.Add(chunk);
            totalChunksCount++;

            if (batchChunkTexts.Count >= MaxBatchSize)
            {
                listOfBatches.Add(new List<TextChunk>(batchChunkTexts));
                batchChunkTexts.Clear();
            }
        }

        // Flush the final, partially filled batch.
        if (batchChunkTexts.Count > 0)
        {
            listOfBatches.Add(new List<TextChunk>(batchChunkTexts));
            batchChunkTexts.Clear();
        }
_logger.LogInformation("Processing list of batches in parallel, total batches: {listSize}, chunks count: {chunksCount}", listOfBatches.Count(), totalChunksCount); await EnsureDocumentTableExistsAsync(connectionString); await Parallel.ForEachAsync(listOfBatches, new ParallelOptions { MaxDegreeOfParallelism = MaxDegreeOfParallelism }, async (batchChunkTexts, cancellationToken) => { _logger.LogInformation("Processing batch of size: {batchSize}", batchChunkTexts.Count); if (batchChunkTexts.Count > 0) { var embeddings = await GenerateEmbeddingsWithRetryAsync(batchChunkTexts); _logger.LogInformation("Embeddings generated: {0}", embeddings.Count); if (embeddings.Count > 0) { // Save into Azure SQL _logger.LogInformation("Begin Saving data in Azure SQL"); for (int index = 0; index < batchChunkTexts.Count; index++) { using (var connection = new SqlConnection(connectionString)) { string SanitizedName = SantizeDatabaseObjectName(TableSchemaName) + "." + SantizeDatabaseObjectName(TableName); string insertQuery = $@"INSERT INTO {SanitizedName} (ChunkId, DocumentUrl, Embedding, ChunkText, PageNumber) VALUES (@ChunkId, @DocumentUrl, @Embedding, @ChunkText, @PageNumber);"; var doc = new Document() { ChunkId = batchChunkTexts[index].ChunkNumber, DocumentUrl = blobClient.Uri.AbsoluteUri, Embedding = JsonSerializer.Serialize(embeddings[index].Vector), ChunkText = batchChunkTexts[index].Text, PageNumber = batchChunkTexts[index].PageNumberIfKnown, }; //connection.AccessToken = token.Token; var result = connection.Execute(insertQuery, doc); } } _logger.LogInformation("End Saving data in Azure SQL"); } } }); _logger.LogInformation("Finished processing blob {name}, total chunks processed {count}.", blobClient.Name, totalChunksCount); } private async Task<EmbeddingCollection> GenerateEmbeddingsWithRetryAsync(IEnumerable<TextChunk> batchChunkTexts) { EmbeddingGenerationOptions embeddingGenerationOptions = new() { Dimensions = embeddingDimensions }; int retryCount = 0; while (retryCount < MaxRetryCount) { try { return await embeddingClient.GenerateEmbeddingsAsync(batchChunkTexts.Select(p => p.Text).ToList(), embeddingGenerationOptions); } catch (ClientResultException ex) { if (ex.Status is ((int)HttpStatusCode.TooManyRequests) or ((int)HttpStatusCode.Unauthorized)) { if (retryCount >= MaxRetryCount) { throw new Exception($"Max retry attempts reached generating embeddings with exception: {ex}."); } retryCount++; await Task.Delay(RetryDelay); } else { throw new Exception($"Failed to generate embeddings with error: {ex}."); } } } throw new Exception($"Failed to generate embeddings after retrying for ${MaxRetryCount} times."); } private async Task HandleBlobDeleteEventAsync(BlobClient blobClient) { // TODO - Implement me :) _logger.LogInformation("Handling delete event for blob name {name}.", blobClient.Name); await Task.Delay(1); } private string SantizeDatabaseObjectName(string name) { string santized = name.Trim(); if (santized.StartsWith('[') && santized.EndsWith(']')) return santized; else return "[" + santized + "]"; } private async Task EnsureDocumentTableExistsAsync(string connectionString) { _logger.LogInformation("Creating table if it does not exist yet..."); string SanitizedName = SantizeDatabaseObjectName(TableSchemaName) + "." 
            + SanitizeDatabaseObjectName(TableName);
        _logger.LogInformation("Document table: {documentTable}", sanitizedName);

        // Size the vector column from the configured embedding dimensions so the
        // table matches the embeddings actually being generated.
        string createDocumentTableScript = $@"
            IF NOT EXISTS (SELECT 1 FROM INFORMATION_SCHEMA.TABLES
                           WHERE TABLE_NAME = PARSENAME('{sanitizedName}', 1)
                             AND TABLE_SCHEMA = PARSENAME('{sanitizedName}', 2))
            BEGIN
                CREATE TABLE {sanitizedName}
                (
                    [Id] INT IDENTITY(1,1) PRIMARY KEY NOT NULL,
                    [ChunkId] INT NULL,
                    [DocumentUrl] VARCHAR(1000) NULL,
                    [Embedding] VECTOR({embeddingDimensions}) NULL,
                    [ChunkText] VARCHAR(MAX) NULL,
                    [PageNumber] INT NULL
                );
            END";

        using (var connection = new SqlConnection(connectionString))
        {
            //connection.AccessToken = token.Token;
            await connection.ExecuteAsync(createDocumentTableScript);
        }
    }
}
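
This file compiles against types defined elsewhere in the repository: TextChunk, TextChunker, and Document. For readers browsing the file in isolation, here is a hypothetical sketch of TextChunk and Document with shapes inferred purely from how this file uses them; the repo's actual definitions may differ.

// Hypothetical sketches only: the real TextChunk and Document types live in
// other files of this repo. Shapes below are inferred from usage in
// BlobTriggerFunction.cs.
namespace DocumentVectorPipelineFunctions;

// A chunk of extracted text: its ordinal, its content, and the page it came
// from (read above as ChunkNumber, Text, and PageNumberIfKnown).
public record TextChunk(int ChunkNumber, string Text, int PageNumberIfKnown);

// Row object bound by Dapper to the INSERT above; property names must match
// the @-prefixed parameters in the query.
public class Document
{
    public int ChunkId { get; set; }
    public string DocumentUrl { get; set; } = string.Empty;
    public string Embedding { get; set; } = string.Empty; // JSON-serialized vector
    public string ChunkText { get; set; } = string.Empty;
    public int PageNumber { get; set; }
}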
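
TextChunker.FixedSizeChunking is likewise defined elsewhere in the repo. As a minimal sketch of what fixed-size chunking over a Document Intelligence AnalyzeResult can look like, assuming the TextChunk shape above: the 500-character budget and the line-based accumulation are illustrative assumptions, not the repo's actual strategy.

using Azure.AI.FormRecognizer.DocumentAnalysis;

namespace DocumentVectorPipelineFunctions;

internal static class TextChunker
{
    // Illustrative chunk budget; the real implementation may use a different
    // size and an overlap strategy.
    private const int MaxChunkChars = 500;

    public static IReadOnlyList<TextChunk> FixedSizeChunking(AnalyzeResult result)
    {
        var chunks = new List<TextChunk>();
        var buffer = new System.Text.StringBuilder();
        int chunkNumber = 0;
        int currentPage = 0;

        foreach (var page in result.Pages)
        {
            currentPage = page.PageNumber;
            foreach (var line in page.Lines)
            {
                // Start a new chunk when adding this line would exceed the budget.
                if (buffer.Length > 0 && buffer.Length + line.Content.Length + 1 > MaxChunkChars)
                {
                    chunks.Add(new TextChunk(chunkNumber++, buffer.ToString(), currentPage));
                    buffer.Clear();
                }
                buffer.AppendLine(line.Content);
            }
        }

        if (buffer.Length > 0)
        {
            chunks.Add(new TextChunk(chunkNumber, buffer.ToString(), currentPage));
        }

        return chunks;
    }
}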
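
For local runs, the function reads the settings named in the code above (the blob trigger connection, the SQL connection string, the embedding dimensions, and the table schema/name overrides). A sketch of a local.settings.json, assuming the standard dotnet-isolated Functions worker; all values are placeholders:

{
  "IsEncrypted": false,
  "Values": {
    "AzureWebJobsStorage": "UseDevelopmentStorage=true",
    "FUNCTIONS_WORKER_RUNTIME": "dotnet-isolated",
    "AzureBlobStorageAccConnectionString": "<blob-storage-connection-string>",
    "SqlConnectionString": "<azure-sql-connection-string>",
    "AzureOpenAIModelDimensions": "1536",
    "DocumentTableSchema": "dbo",
    "DocumentTableName": "documents",
    "AzureManagedIdentityClientId": "<managed-identity-client-id>"
  }
}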