private async Task UpsertRowsAsync()

in src/SqlAsyncCollector.cs [171:315]


        /// <summary>
        /// Upserts <paramref name="rows"/> into the table named by the CommandText of <paramref name="attribute"/>.
        /// The rows are sent in batches, all inside a single transaction that is rolled back if any batch fails.
        /// </summary>
        /// <param name="rows">The rows to write. The first row is used to work out which columns are being written.</param>
        /// <param name="attribute">Supplies the connection string setting name and the target table name (CommandText).</param>
        /// <param name="configuration">Passed to BuildConnection together with the connection string setting name.</param>
        /// <exception cref="InvalidOperationException">
        /// Thrown when the first row has properties that are not table columns, when it targets columns of an unsupported type,
        /// when it yields no columns to write, when required primary keys are missing, or when executing the batches fails
        /// (the original exception is attached as the inner exception).
        /// </exception>
        /// <exception cref="AggregateException">Thrown when executing the batches fails and the rollback fails as well.</exception>
        private async Task UpsertRowsAsync(IList<T> rows, SqlAttribute attribute, IConfiguration configuration)
        {
            using (SqlConnection connection = BuildConnection(attribute.ConnectionStringSetting, configuration))
            {
                await connection.OpenAsync();
                this._serverProperties = await GetServerTelemetryProperties(connection, this._logger, CancellationToken.None);
                Dictionary<TelemetryPropertyName, string> props = connection.AsConnectionProps(this._serverProperties);

                string fullTableName = attribute.CommandText;

                // Include the connection string hash as part of the key in case this customer has the same table in two different Sql Servers
                string cacheKey = $"{connection.ConnectionString.GetHashCode()}-{fullTableName}";

                ObjectCache cachedTables = MemoryCache.Default;

                // The cache timeout can be overridden through an environment variable; an unparsable value leaves the default in place.
                int timeout = AZ_FUNC_TABLE_INFO_CACHE_TIMEOUT_MINUTES;
                string timeoutEnvVar = Environment.GetEnvironmentVariable("AZ_FUNC_TABLE_INFO_CACHE_TIMEOUT_MINUTES");
                if (!string.IsNullOrEmpty(timeoutEnvVar)
                    && int.TryParse(timeoutEnvVar, NumberStyles.Integer, CultureInfo.InvariantCulture, out int parsedTimeout))
                {
                    timeout = parsedTimeout;
                    this._logger.LogDebug($"Overriding default table info cache timeout with new value {timeout}");
                }

                if (!(cachedTables[cacheKey] is TableInformation tableInfo))
                {
                    TelemetryInstance.TrackEvent(TelemetryEventName.TableInfoCacheMiss, props);
                    // set the columnNames for supporting T as JObject since it doesn't have columns in the member info.
                    tableInfo = TableInformation.RetrieveTableInformation(connection, fullTableName, this._logger, this._serverProperties);
                    var policy = new CacheItemPolicy
                    {
                        // Re-look up the primary key(s) after timeout (default timeout is 10 minutes)
                        AbsoluteExpiration = DateTimeOffset.Now.AddMinutes(timeout)
                    };

                    cachedTables.Set(cacheKey, tableInfo, policy);
                }
                else
                {
                    TelemetryInstance.TrackEvent(TelemetryEventName.TableInfoCacheHit, props);
                }

                // Every check below inspects the first row only, so fetch it once.
                T firstRow = rows.First();

                IEnumerable<string> extraProperties = GetExtraProperties(tableInfo.Columns, firstRow);
                if (extraProperties.Any())
                {
                    string message = $"The following properties in {typeof(T)} do not exist in the table {fullTableName}: {string.Join(", ", extraProperties.ToArray())}.";
                    var ex = new InvalidOperationException(message);
                    TelemetryInstance.TrackException(TelemetryErrorName.PropsNotExistOnTable, ex, props);
                    throw ex;
                }

                // Materialize the column names once; they are reused by several checks and by the query generation.
                List<string> columnNamesFromItem = GetColumnNamesFromItem(firstRow).ToList();

                List<string> unsupportedColumns = columnNamesFromItem
                    .Where(prop => UnsupportedTypes.Contains(tableInfo.Columns[prop], StringComparer.OrdinalIgnoreCase))
                    .ToList();
                if (unsupportedColumns.Count > 0)
                {
                    string message = $"The type(s) of the following column(s) are not supported: {string.Join(", ", unsupportedColumns.ToArray())}. See https://github.com/Azure/azure-functions-sql-extension#output-bindings for more details.";
                    throw new InvalidOperationException(message);
                }

                List<string> bracketedColumnNamesFromItem = columnNamesFromItem
                    .Where(prop => !tableInfo.PrimaryKeys.Any(k => k.IsIdentity && string.Equals(k.Name, prop, StringComparison.Ordinal))) // Skip any identity columns, those should never be updated
                    .Select(prop => prop.AsBracketQuotedString())
                    .ToList();
                if (bracketedColumnNamesFromItem.Count == 0)
                {
                    string message = "No property values found in item to upsert. If using query parameters, ensure that the casing of the parameter names and the property names match.";
                    throw new InvalidOperationException(message);
                }

                var table = new SqlObject(fullTableName);

                // Primary key names are matched against the item's column names with ordinal (case-sensitive) comparison.
                var primaryKeysFromObject = new HashSet<string>(
                    columnNamesFromItem.Where(f => tableInfo.PrimaryKeys.Any(k => string.Equals(k.Name, f, StringComparison.Ordinal))),
                    StringComparer.Ordinal);
                List<PrimaryKey> missingPrimaryKeysFromItem = tableInfo.PrimaryKeys
                    .Where(k => !primaryKeysFromObject.Contains(k.Name))
                    .ToList();
                bool hasMissingPrimaryKeys = missingPrimaryKeysFromItem.Count > 0;
                bool hasIdentityOrDefaultPrimaryKeys = tableInfo.HasIdentityColumnPrimaryKeys || tableInfo.HasDefaultColumnPrimaryKeys;

                // If none of the primary keys are an identity column or have a default value then we require that all primary keys be present in the POCO so we can
                // generate the MERGE statement correctly
                if (!hasIdentityOrDefaultPrimaryKeys && hasMissingPrimaryKeys)
                {
                    string message = $"All primary keys for SQL table {table} need to be found in '{typeof(T)}.' Missing primary keys: [{string.Join(",", missingPrimaryKeysFromItem)}]";
                    var ex = new InvalidOperationException(message);
                    TelemetryInstance.TrackException(TelemetryErrorName.MissingPrimaryKeys, ex, connection.AsConnectionProps(this._serverProperties));
                    throw ex;
                }

                // If any identity columns or columns with default values aren't included in the object then we have to generate a basic insert since the merge statement expects all primary key
                // columns to exist. (the merge statement can handle nullable columns though if those exist)
                QueryType queryType = hasIdentityOrDefaultPrimaryKeys && hasMissingPrimaryKeys ? QueryType.Insert : QueryType.Merge;
                string mergeOrInsertQuery = queryType == QueryType.Insert
                    ? TableInformation.GetInsertQuery(table, bracketedColumnNamesFromItem)
                    : TableInformation.GetMergeQuery(tableInfo.PrimaryKeys, table, bracketedColumnNamesFromItem);

                var transactionSw = Stopwatch.StartNew();
                int batchSize = 1000;

                // Both the transaction and the command are IDisposable; the using blocks release them on success and on failure.
                using (SqlTransaction transaction = connection.BeginTransaction())
                {
                    try
                    {
                        int batchCount = 0;
                        var commandSw = Stopwatch.StartNew();
                        using (SqlCommand command = connection.CreateCommand())
                        {
                            command.Transaction = transaction;
                            // A single parameter is reused for every batch; only its value and the command text change.
                            SqlParameter par = command.Parameters.Add(RowDataParameter, SqlDbType.NVarChar, -1);
                            foreach (IEnumerable<T> batch in rows.Batch(batchSize))
                            {
                                batchCount++;
                                GenerateDataQueryForMerge(tableInfo, batch, out string newDataQuery, out string rowData);
                                command.CommandText = $"{newDataQuery} {mergeOrInsertQuery};";
                                par.Value = rowData;
                                await command.ExecuteNonQueryAsyncWithLogging(this._logger, CancellationToken.None);
                            }
                        }

                        transaction.Commit();
                        transactionSw.Stop();
                        var measures = new Dictionary<TelemetryMeasureName, double>()
                        {
                            { TelemetryMeasureName.BatchCount, batchCount },
                            { TelemetryMeasureName.TransactionDurationMs, transactionSw.ElapsedMilliseconds },
                            { TelemetryMeasureName.CommandDurationMs, commandSw.ElapsedMilliseconds },
                            { TelemetryMeasureName.BatchSize, batchSize },
                            { TelemetryMeasureName.NumRows, rows.Count }
                        };
                        TelemetryInstance.TrackEvent(TelemetryEventName.Upsert, props, measures);
                    }
                    catch (Exception ex)
                    {
                        try
                        {
                            TelemetryInstance.TrackException(TelemetryErrorName.Upsert, ex, props);
                            transaction.Rollback();
                        }
                        catch (Exception ex2)
                        {
                            // The rollback itself failed: report both the original failure and the rollback failure.
                            TelemetryInstance.TrackException(TelemetryErrorName.UpsertRollback, ex2, props);
                            string message2 = "Encountered exception during upsert and rollback.";
                            throw new AggregateException(message2, new List<Exception> { ex, ex2 });
                        }
                        throw new InvalidOperationException("Unexpected error upserting rows", ex);
                    }
                }
            }
        }