EFCore/src/Scaffolding/Internal/MySQLDatabaseModelFactory.cs (549 lines of code) (raw):
// Copyright © 2021, 2025, Oracle and/or its affiliates.
//
// This program is free software; you can redistribute it and/or modify
// it under the terms of the GNU General Public License, version 2.0, as
// published by the Free Software Foundation.
//
// This program is designed to work with certain software (including
// but not limited to OpenSSL) that is licensed under separate terms, as
// designated in a particular file or component or in included license
// documentation. The authors of MySQL hereby grant you an additional
// permission to link the program and your derivative works with the
// separately licensed software that they have either included with
// the program or referenced in the documentation.
//
// Without limiting anything contained in the foregoing, this file,
// which is part of MySQL Connector/NET, is also subject to the
// Universal FOSS Exception, version 1.0, a copy of which can be found at
// http://oss.oracle.com/licenses/universal-foss-exception.
//
// This program is distributed in the hope that it will be useful, but
// WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
// See the GNU General Public License, version 2.0, for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program; if not, write to the Free Software Foundation, Inc.,
// 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
using Microsoft.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore.Diagnostics;
using Microsoft.EntityFrameworkCore.Metadata;
using Microsoft.EntityFrameworkCore.Migrations;
using Microsoft.EntityFrameworkCore.Scaffolding;
using Microsoft.EntityFrameworkCore.Scaffolding.Metadata;
using Microsoft.Extensions.Logging;
using MySql.Data.MySqlClient;
using MySql.EntityFrameworkCore.Infrastructure.Internal;
using MySql.EntityFrameworkCore.Internal;
using MySql.EntityFrameworkCore.Properties;
using MySql.EntityFrameworkCore.Utils;
using System;
using System.Collections.Generic;
using System.Data;
using System.Data.Common;
using System.Globalization;
using System.Linq;
using System.Text;
using System.Text.RegularExpressions;
namespace MySql.EntityFrameworkCore.Scaffolding.Internal
{
internal class MySQLDatabaseModelFactory : DatabaseModelFactory
{
/// <summary>Scaffolding logger; used to report problems found while reading keys, indexes and constraints.</summary>
private readonly IDiagnosticsLogger<DbLoggerCategory.Scaffolding> _logger;

/// <summary>Provider options; initialized from the connection when they still hold default connection settings.</summary>
private readonly IMySQLOptions _options;

/// <summary>
/// Creates a new database model factory.
/// </summary>
/// <param name="logger">Logger for the scaffolding category. Must not be null.</param>
/// <param name="options">MySQL provider options used by this factory.</param>
public MySQLDatabaseModelFactory([NotNull] IDiagnosticsLogger<DbLoggerCategory.Scaffolding> logger, IMySQLOptions options)
{
    Check.NotNull(logger, nameof(logger));

    (_logger, _options) = (logger, options);
}
/// <inheritdoc/>
public override DatabaseModel Create(string connectionString, DatabaseModelFactoryOptions options)
{
    Check.NotEmpty(connectionString, nameof(connectionString));
    Check.NotNull(options, nameof(options));

    // The connection is created here, so it is disposed here once the model has been built.
    using var connection = new MySqlConnection(connectionString);
    return Create(connection, options);
}
/// <summary>
/// Builds a <see cref="DatabaseModel"/> from the given connection, restricted to the
/// schemas and tables requested in <paramref name="options"/>.
/// </summary>
/// <param name="connection">Connection to read metadata from. It is opened if needed and closed again only if this method opened it.</param>
/// <param name="options">Schemas and tables to include in the model.</param>
/// <returns>The model containing every table and view that matched the filter.</returns>
public override DatabaseModel Create(DbConnection connection, DatabaseModelFactoryOptions options)
{
    Check.NotNull(connection, nameof(connection));
    Check.NotNull(options, nameof(options));

    SetupMySQLOptions(connection);

    // Remember who opened the connection so a caller-owned open connection is left open.
    var openedHere = connection.State != ConnectionState.Open;
    if (openedHere)
    {
        connection.Open();
    }

    try
    {
        var model = new DatabaseModel
        {
            DatabaseName = connection.Database,
            DefaultSchema = GetDefaultSchema(connection)
        };

        var schemaFilter = GenerateSchemaFilter(options.Schemas.ToList(), model.DefaultSchema);
        var requestedTables = options.Tables.ToList().Select(Parse).ToList();
        var tableFilter = GenerateTableFilter(requestedTables, schemaFilter);

        foreach (var table in GetTables(connection, tableFilter))
        {
            table.Database = model;
            model.Tables.Add(table);
        }

        return model;
    }
    finally
    {
        if (openedHere)
        {
            connection.Close();
        }
    }
}
/// <summary>
/// Initializes the provider options from <paramref name="connection"/> when the current
/// connection settings are equal to those of a freshly constructed <c>MySQLOptions</c>.
/// </summary>
/// <param name="connection">Connection used to build the context options.</param>
private void SetupMySQLOptions(DbConnection connection)
{
    var currentSettings = _options.ConnectionSettings;
    var defaultSettings = new MySQLOptions().ConnectionSettings;
    if (!currentSettings.Equals(defaultSettings))
    {
        // Options were already configured; leave them alone.
        return;
    }

    var contextOptions = new DbContextOptionsBuilder()
        .UseMySQL(connection)
        .Options;
    _options.Initialize(contextOptions);
}
/// <summary>
/// Runs <c>SELECT SCHEMA()</c> on the connection and returns the result.
/// </summary>
/// <param name="connection">An open connection.</param>
/// <returns>The scalar result when it is a string; otherwise <c>null</c>.</returns>
private string? GetDefaultSchema(DbConnection connection)
{
    using var command = connection.CreateCommand();
    command.CommandText = "SELECT SCHEMA()";

    // Any non-string result (including a database NULL) maps to null.
    return command.ExecuteScalar() as string;
}
/// <summary>
/// Splits a table specification of the form <c>table</c>, <c>schema.table</c>,
/// <c>[table]</c> or <c>[schema].[table]</c> into its schema and table parts.
/// </summary>
/// <param name="table">The table specification; surrounding whitespace is ignored.</param>
/// <returns>The schema (or <c>null</c> when only one part was given) and the table name.</returns>
/// <exception cref="InvalidOperationException">The specification does not match the expected format.</exception>
private static (string? Schema, string Table) Parse(string table)
{
    var match = _partExtractor.Match(table.Trim());
    if (!match.Success)
    {
        throw new InvalidOperationException(string.Format(MySQLStrings.InvalidTableToIncludeInScaffolding, table));
    }

    // Inside a bracketed part "]]" is the escape for a literal "]", so it is unescaped
    // to a single bracket rather than removed (removing it would alter the name).
    var part1 = match.Groups["part1"].Value.Replace("]]", "]");
    var part2 = match.Groups["part2"].Value.Replace("]]", "]");

    // With a single part that part is the table; with two parts the first one is the schema.
    return string.IsNullOrEmpty(part2) ? (null, part1) : (part1, part2);
}
// Pattern template for one name part; {0} is the numeric suffix of the capture group
// ("part1", "part2"). A part is either bracketed ("[...]", where "]]" stands for a
// literal "]") or a bare run of characters containing no '.', '[' or ']'.
private const string NamePartRegex
= @"(?:(?:\[(?<part{0}>(?:(?:\]\])|[^\]])+)\])|(?<part{0}>[^\.\[\]]+))";
// Matches a whole table specification: one name part, optionally followed by '.' and a
// second name part. Compiled once and guarded by a one-second match timeout.
private static readonly Regex _partExtractor
= new Regex(
string.Format(
CultureInfo.InvariantCulture,
@"^{0}(?:\.{1})?$",
string.Format(CultureInfo.InvariantCulture, NamePartRegex, 1),
string.Format(CultureInfo.InvariantCulture, NamePartRegex, 2)),
RegexOptions.Compiled,
TimeSpan.FromMilliseconds(1000));
/// <summary>
/// Creates a function that renders an SQL <c>IN (...)</c> predicate restricting a schema
/// column to the requested schemas, or to the default schema when none were requested.
/// </summary>
/// <param name="schemas">Schemas explicitly requested for scaffolding.</param>
/// <param name="defaultSchema">Schema used when <paramref name="schemas"/> is empty.</param>
/// <returns>
/// A function taking the schema column expression and returning the predicate, or
/// <c>null</c> when there are no schemas and no default schema.
/// </returns>
private static Func<string, string>? GenerateSchemaFilter(IReadOnlyList<string> schemas, string? defaultSchema)
{
    if (schemas.Count == 0 && defaultSchema == null)
    {
        return null;
    }

    return schemaColumn =>
    {
        // Explicit schemas take precedence; the default schema is only a fallback.
        var literals = schemas.Count > 0
            ? string.Join(", ", schemas.Select(EscapeLiteral))
            : EscapeLiteral(defaultSchema!);

        return new StringBuilder()
            .Append(schemaColumn)
            .Append(" IN (")
            .Append(literals)
            .Append(")")
            .ToString();
    };
}
/// <summary>
/// Creates a function that renders the parenthesized SQL predicate selecting the tables
/// to scaffold, combining the schema predicate with the requested table names.
/// </summary>
/// <param name="tables">Requested tables, each with an optional schema.</param>
/// <param name="schemaFilter">Schema predicate generator, or <c>null</c> for none.</param>
/// <returns>
/// A function taking the schema and table column expressions and returning the predicate,
/// or <c>null</c> when there is neither a schema filter nor any requested table.
/// </returns>
private static Func<string, string, string>? GenerateTableFilter(IReadOnlyList<(string? Schema, string Table)> tables, Func<string, string>? schemaFilter)
{
    if (schemaFilter == null && tables.Count == 0)
    {
        return null;
    }

    return (s, t) =>
    {
        var hasSchemaFilter = schemaFilter != null;
        var hasTables = tables.Count > 0;
        if (!hasSchemaFilter && !hasTables)
        {
            return string.Empty;
        }

        var sql = new StringBuilder("(");

        if (hasSchemaFilter)
        {
            sql.Append(schemaFilter!(s));
        }

        if (hasTables)
        {
            if (hasSchemaFilter)
            {
                sql.AppendLine().Append("AND ");
            }

            var unqualified = tables
                .Where(e => string.IsNullOrEmpty(e.Schema))
                .Select(e => EscapeLiteral(e.Table))
                .ToList();
            var qualified = tables
                .Where(e => !string.IsNullOrEmpty(e.Schema))
                .Select(e => EscapeLiteral($"{e.Schema}.{e.Table}"))
                .ToList();

            // Tables given without a schema are matched on the table column alone.
            if (unqualified.Count > 0)
            {
                sql.Append(t)
                    .Append(" IN (")
                    .Append(string.Join(", ", unqualified))
                    .Append(")");
            }

            // Schema-qualified tables are matched on "schema.table".
            // NOTE(review): no extra parentheses are emitted around the two table predicates,
            // so SQL precedence groups this as (schema AND unqualified) OR qualified - confirm
            // that this is the intended semantics.
            if (qualified.Count > 0)
            {
                if (unqualified.Count > 0)
                {
                    sql.Append(" OR ");
                }

                sql.Append("CONCAT_WS(N'.',")
                    .Append(string.Join(",", s, t))
                    .Append(") IN (")
                    .Append(string.Join(", ", qualified))
                    .Append(")");
            }
        }

        return sql.Append(")").ToString();
    };
}
/// <summary>
/// Maps a delete-rule name to the matching <see cref="ReferentialAction"/>.
/// </summary>
/// <param name="deleteAction">Rule name such as <c>CASCADE</c>; compared case-insensitively.</param>
/// <returns>The corresponding action, or <c>null</c> when the rule name is not recognized.</returns>
private static ReferentialAction? ConvertToReferentialAction(string deleteAction)
{
    return deleteAction.ToUpperInvariant() switch
    {
        "RESTRICT" => ReferentialAction.Restrict,
        "CASCADE" => ReferentialAction.Cascade,
        "SET NULL" => ReferentialAction.SetNull,
        "SET DEFAULT" => ReferentialAction.SetDefault,
        "NO ACTION" => ReferentialAction.NoAction,
        _ => (ReferentialAction?)null
    };
}
// In every query below, {0} is replaced (via string.Format in the corresponding Get*
// method) with the table filter predicate.

// Lists base tables and views with their comments. A view whose comment is exactly
// 'VIEW' is reported with an empty comment instead.
private const string GetTablesQuery = @"SELECT TABLE_SCHEMA, TABLE_NAME, TABLE_TYPE,
IF(TABLE_COMMENT = 'VIEW' AND TABLE_TYPE = 'VIEW', '', TABLE_COMMENT) AS TABLE_COMMENT
FROM INFORMATION_SCHEMA.TABLES
WHERE TABLE_TYPE IN ('BASE TABLE', 'VIEW')
AND {0};";
// Returns the index named 'PRIMARY' for each table, with its column names joined by ','
// in index order (COLUMNS).
private const string GetPrimaryQuery = @"SELECT
`TABLE_SCHEMA`,
`TABLE_NAME`,
`INDEX_NAME`,
GROUP_CONCAT(`COLUMN_NAME` ORDER BY `SEQ_IN_INDEX` SEPARATOR ',') AS COLUMNS
FROM `INFORMATION_SCHEMA`.`STATISTICS`
WHERE {0}
AND `INDEX_NAME` = 'PRIMARY'
GROUP BY `TABLE_SCHEMA`, `TABLE_NAME`, `INDEX_NAME`, `NON_UNIQUE`;";
// Returns one row per column, ordered by ordinal position.
private const string GetColumnsQuery = @"SELECT
`TABLE_SCHEMA`,
`TABLE_NAME`,
`COLUMN_NAME`,
`ORDINAL_POSITION`,
`COLUMN_DEFAULT`,
`IS_NULLABLE`,
`DATA_TYPE`,
`CHARACTER_SET_NAME`,
`COLLATION_NAME`,
`COLUMN_TYPE`,
`COLUMN_COMMENT`,
`EXTRA`
FROM
`INFORMATION_SCHEMA`.`COLUMNS`
WHERE {0}
ORDER BY
`ORDINAL_POSITION`;";
// Returns every index other than 'PRIMARY', with NON_UNIQUE rendered as the text
// 'TRUE'/'FALSE' and the column names joined by ',' in index order (COLUMNS).
private const string GetIndexesQuery = @"SELECT
`TABLE_SCHEMA`,
`TABLE_NAME`,
`INDEX_NAME`,
IF(`NON_UNIQUE`, 'TRUE', 'FALSE') AS NON_UNIQUE,
GROUP_CONCAT(`COLUMN_NAME` ORDER BY `SEQ_IN_INDEX` SEPARATOR ',') AS COLUMNS
FROM `INFORMATION_SCHEMA`.`STATISTICS`
WHERE {0}
AND `INDEX_NAME` <> 'PRIMARY'
GROUP BY `TABLE_SCHEMA`, `TABLE_NAME`, `INDEX_NAME`, `NON_UNIQUE`;";
// Returns foreign key constraints (rows that have a referenced table). PAIRED_COLUMNS
// holds 'column|referenced_column' pairs joined by ',' in ordinal order; DELETE_RULE is
// looked up in REFERENTIAL_CONSTRAINTS by constraint name and schema.
// NOTE(review): the schema of the referenced table is not selected, so consumers can only
// identify the referenced table by name.
private const string GetConstraintsQuery = @"SELECT
`TABLE_SCHEMA`,
`TABLE_NAME`,
`CONSTRAINT_NAME`,
`REFERENCED_TABLE_NAME`,
GROUP_CONCAT(CONCAT_WS('|', `COLUMN_NAME`, `REFERENCED_COLUMN_NAME`) ORDER BY `ORDINAL_POSITION` SEPARATOR ',') AS PAIRED_COLUMNS,
(SELECT `DELETE_RULE` FROM `INFORMATION_SCHEMA`.`REFERENTIAL_CONSTRAINTS`
WHERE `REFERENTIAL_CONSTRAINTS`.`CONSTRAINT_NAME` = `KEY_COLUMN_USAGE`.`CONSTRAINT_NAME`
AND `REFERENTIAL_CONSTRAINTS`.`CONSTRAINT_SCHEMA` = `KEY_COLUMN_USAGE`.`CONSTRAINT_SCHEMA`) AS `DELETE_RULE`
FROM `INFORMATION_SCHEMA`.`KEY_COLUMN_USAGE`
WHERE {0}
AND `CONSTRAINT_NAME` <> 'PRIMARY'
AND `REFERENCED_TABLE_NAME` IS NOT NULL
GROUP BY `TABLE_SCHEMA`, `TABLE_NAME`,
`CONSTRAINT_SCHEMA`, `CONSTRAINT_NAME`,
`TABLE_NAME`, `REFERENCED_TABLE_NAME`;";
/// <summary>
/// Reads the tables and views matching <paramref name="tableFilter"/> and then populates
/// their columns, primary keys, indexes and foreign keys.
/// </summary>
/// <param name="connection">An open connection.</param>
/// <param name="tableFilter">
/// Generator of the SQL predicate, given the schema and table column expressions.
/// </param>
/// <returns>The populated tables and views.</returns>
/// <exception cref="InvalidOperationException">
/// <paramref name="tableFilter"/> is <c>null</c>, i.e. there is no default schema and no
/// schema or table was requested, so no predicate can be generated for the queries.
/// </exception>
private IEnumerable<DatabaseTable> GetTables(DbConnection connection, Func<string, string, string>? tableFilter)
{
    // Every metadata query requires a predicate; fail with a clear message instead of
    // dereferencing a null filter.
    if (tableFilter == null)
    {
        throw new InvalidOperationException(
            "Unable to determine which tables to scaffold: the connection has no default database and no schemas or tables were specified.");
    }

    var filter = tableFilter("TABLE_SCHEMA", "TABLE_NAME");
    var tables = new List<DatabaseTable>();

    using (var command = connection.CreateCommand())
    {
        command.CommandText = string.Format(GetTablesQuery, filter);
        using (var reader = command.ExecuteReader())
        {
            while (reader.Read())
            {
                var schema = reader.GetValueOrDefault<string>("TABLE_SCHEMA");
                var name = reader.GetValueOrDefault<string>("TABLE_NAME");
                var type = reader.GetValueOrDefault<string>("TABLE_TYPE");
                var comment = reader.GetValueOrDefault<string>("TABLE_COMMENT");

                // Anything that is not a base table is modeled as a view.
                var table = string.Equals(type, "base table", StringComparison.OrdinalIgnoreCase)
                    ? new DatabaseTable()
                    : new DatabaseView();
                table.Schema = schema;
                table.Name = name!;
                table.Comment = string.IsNullOrEmpty(comment) ? null : comment;
                tables.Add(table);
            }
        }
    }

    // Related metadata is loaded with separate commands once the reader above is closed.
    // Columns come first because keys, indexes and constraints resolve columns by name.
    GetColumns(connection, tables, filter);
    GetPrimaryKeys(connection, tables, filter);
    GetIndexes(connection, tables, filter);
    GetConstraints(connection, tables, filter);

    return tables;
}
/// <summary>
/// Reads column metadata for the filtered tables and adds a <see cref="DatabaseColumn"/>
/// to the matching entry of <paramref name="tables"/> for each row.
/// </summary>
/// <param name="connection">An open connection.</param>
/// <param name="tables">Tables previously read with the same filter.</param>
/// <param name="tableFilter">SQL predicate selecting the tables.</param>
private void GetColumns(DbConnection connection, IReadOnlyList<DatabaseTable> tables, string tableFilter)
{
    using (var command = connection.CreateCommand())
    {
        command.CommandText = string.Format(GetColumnsQuery, tableFilter);
        using (var reader = command.ExecuteReader())
        {
            var tableColumnGroups = reader.Cast<DbDataRecord>()
                .GroupBy(
                    ddr => (tableSchema: ddr.GetValueOrDefault<string>("TABLE_SCHEMA"),
                        tableName: ddr.GetValueOrDefault<string>("TABLE_NAME")));

            foreach (var tableColumnGroup in tableColumnGroups)
            {
                var tableSchema = tableColumnGroup.Key.tableSchema;
                var tableName = tableColumnGroup.Key.tableName;
                var table = tables.Single(t => t.Schema == tableSchema && t.Name == tableName);

                foreach (var dataRecord in tableColumnGroup)
                {
                    var name = dataRecord.GetValueOrDefault<string>("COLUMN_NAME");
                    var defaultValue = dataRecord.GetValueOrDefault<string>("COLUMN_DEFAULT");
                    var nullable = dataRecord.GetValueOrDefault<string>("IS_NULLABLE")!.Contains("YES");
                    var dataType = dataRecord.GetValueOrDefault<string>("DATA_TYPE");
                    var columnType = dataRecord.GetValueOrDefault<string>("COLUMN_TYPE");
                    var extra = dataRecord.GetValueOrDefault<string>("EXTRA");
                    var comment = dataRecord.GetValueOrDefault<string>("COLUMN_COMMENT");

                    // Decided from the raw default, before defaults equal to the CLR default are dropped.
                    var valueGenerated = GetValueGenerated(extra!, dataType, defaultValue);

                    defaultValue = FilterClrDefaults(dataType!, nullable, defaultValue);

                    table.Columns.Add(new DatabaseColumn
                    {
                        Table = table,
                        Name = name!,
                        StoreType = columnType,
                        IsNullable = nullable,
                        DefaultValueSql = CreateDefaultValueString(defaultValue, dataType!),
                        ValueGenerated = valueGenerated,
                        Comment = string.IsNullOrEmpty(comment) ? null : comment
                    });
                }
            }
        }
    }
}

/// <summary>
/// Derives the value generation strategy of a column from its EXTRA attribute.
/// </summary>
/// <param name="extra">The EXTRA value of the column.</param>
/// <param name="dataType">The DATA_TYPE value of the column.</param>
/// <param name="defaultValue">The unfiltered COLUMN_DEFAULT value of the column.</param>
/// <returns>
/// <c>OnAdd</c> when EXTRA contains "auto_increment"; <c>Never</c> when EXTRA contains no
/// "on update"; otherwise <c>OnAddOrUpdate</c> when the column has a default or is a
/// timestamp/datetime column whose EXTRA contains CURRENT_TIMESTAMP after its first
/// character, and <c>OnUpdate</c> in the remaining case.
/// </returns>
private static ValueGenerated GetValueGenerated(string extra, string? dataType, string? defaultValue)
{
    if (extra.IndexOf("auto_increment", StringComparison.Ordinal) >= 0)
    {
        return ValueGenerated.OnAdd;
    }

    if (extra.IndexOf("on update", StringComparison.Ordinal) < 0)
    {
        return ValueGenerated.Never;
    }

    var isTemporal = string.Equals(dataType, "timestamp", StringComparison.OrdinalIgnoreCase)
        || string.Equals(dataType, "datetime", StringComparison.OrdinalIgnoreCase);
    var updatesToCurrentTimestamp = isTemporal
        && extra.IndexOf("CURRENT_TIMESTAMP", StringComparison.Ordinal) > 0;

    return defaultValue != null || updatesToCurrentTimestamp
        ? ValueGenerated.OnAddOrUpdate
        : ValueGenerated.OnUpdate;
}
/// <summary>
/// Reads the primary key of each filtered table and assigns it to the matching entry of
/// <paramref name="tables"/>. A failure while building a key is logged and skipped.
/// </summary>
/// <param name="connection">An open connection.</param>
/// <param name="tables">Tables whose columns have already been loaded.</param>
/// <param name="tableFilter">SQL predicate selecting the tables.</param>
private void GetPrimaryKeys(DbConnection connection, IReadOnlyList<DatabaseTable> tables, string tableFilter)
{
    using var command = connection.CreateCommand();
    command.CommandText = string.Format(GetPrimaryQuery, tableFilter);
    using var reader = command.ExecuteReader();

    var rowsByTable = reader.Cast<DbDataRecord>()
        .GroupBy(
            row => (tableSchema: row.GetValueOrDefault<string>("TABLE_SCHEMA"),
                tableName: row.GetValueOrDefault<string>("TABLE_NAME")));

    foreach (var tableRows in rowsByTable)
    {
        var (schemaName, tableName) = tableRows.Key;
        var table = tables.Single(t => t.Schema == schemaName && t.Name == tableName);

        foreach (var row in tableRows)
        {
            try
            {
                var primaryKey = new DatabasePrimaryKey
                {
                    Table = table,
                    Name = row.GetValueOrDefault<string>("INDEX_NAME")
                };

                // COLUMNS holds the key's column names separated by ','.
                var columnNames = row.GetValueOrDefault<string>("COLUMNS")!.Split(',');
                foreach (var columnName in columnNames)
                {
                    primaryKey.Columns.Add(table.Columns.Single(c => c.Name == columnName));
                }

                table.PrimaryKey = primaryKey;
            }
            catch (Exception ex)
            {
                _logger.Logger.LogError(ex, "Error assigning PK for {table}.", table.Name);
            }
        }
    }
}
/// <summary>
/// Reads the non-primary indexes of the filtered tables and adds them to the matching
/// entries of <paramref name="tables"/>. A failure while building an index is logged and skipped.
/// </summary>
/// <param name="connection">An open connection.</param>
/// <param name="tables">Tables whose columns have already been loaded.</param>
/// <param name="tableFilter">SQL predicate selecting the tables.</param>
private void GetIndexes(DbConnection connection, IReadOnlyList<DatabaseTable> tables, string tableFilter)
{
    using var command = connection.CreateCommand();
    command.CommandText = string.Format(GetIndexesQuery, tableFilter);
    using var reader = command.ExecuteReader();

    var rowsByTable = reader.Cast<DbDataRecord>()
        .GroupBy(
            row => (tableSchema: row.GetValueOrDefault<string>("TABLE_SCHEMA"),
                tableName: row.GetValueOrDefault<string>("TABLE_NAME")));

    foreach (var tableRows in rowsByTable)
    {
        var (schemaName, tableName) = tableRows.Key;
        var table = tables.Single(t => t.Schema == schemaName && t.Name == tableName);

        foreach (var row in tableRows)
        {
            try
            {
                var indexName = row.GetValueOrDefault<string>("INDEX_NAME");

                // The query renders NON_UNIQUE as the text 'TRUE' or 'FALSE'.
                var nonUnique = bool.Parse(row.GetValueOrDefault<string>("NON_UNIQUE")!);

                var index = new DatabaseIndex
                {
                    Table = table,
                    Name = indexName,
                    IsUnique = !nonUnique
                };

                var columnNames = row.GetValueOrDefault<string>("COLUMNS")!.Split(',');
                foreach (var columnName in columnNames)
                {
                    index.Columns.Add(table.Columns.Single(c => c.Name == columnName));
                }

                table.Indexes.Add(index);
            }
            catch (Exception ex)
            {
                _logger.Logger.LogError(ex, "Error assigning index for {table}.", table.Name);
            }
        }
    }
}
/// <summary>
/// Reads the foreign key constraints of the filtered tables and adds them to the matching
/// entries of <paramref name="tables"/>. Constraints whose referenced table was not
/// scaffolded are skipped with a warning; a failure while building a foreign key is
/// logged and skipped, as is done for primary keys and indexes.
/// </summary>
/// <param name="connection">An open connection.</param>
/// <param name="tables">Tables whose columns have already been loaded.</param>
/// <param name="tableFilter">SQL predicate selecting the tables.</param>
private void GetConstraints(DbConnection connection, IReadOnlyList<DatabaseTable> tables, string tableFilter)
{
    using (var command = connection.CreateCommand())
    {
        command.CommandText = string.Format(GetConstraintsQuery, tableFilter);
        using (var reader = command.ExecuteReader())
        {
            var tableConstraintGroups = reader.Cast<DbDataRecord>()
                .GroupBy(
                    ddr => (tableSchema: ddr.GetValueOrDefault<string>("TABLE_SCHEMA"),
                        tableName: ddr.GetValueOrDefault<string>("TABLE_NAME")));

            foreach (var tableConstraintGroup in tableConstraintGroups)
            {
                var tableSchema = tableConstraintGroup.Key.tableSchema;
                var tableName = tableConstraintGroup.Key.tableName;
                var table = tables.Single(t => t.Schema == tableSchema && t.Name == tableName);

                foreach (var dataRecord in tableConstraintGroup)
                {
                    var referencedTableName = dataRecord.GetValueOrDefault<string>("REFERENCED_TABLE_NAME");

                    // The query does not return the schema of the referenced table, so a table
                    // with that name in the dependent table's own schema is preferred before
                    // falling back to the first table with that name in any schema.
                    var referencedTable =
                        tables.FirstOrDefault(t => t.Schema == tableSchema && t.Name == referencedTableName)
                        ?? tables.FirstOrDefault(t => t.Name == referencedTableName);

                    if (referencedTable == null)
                    {
                        _logger.Logger.LogWarning("Referenced table `{referencedTableName}` is not in dictionary.", referencedTableName);
                        continue;
                    }

                    try
                    {
                        var deleteRule = dataRecord.GetValueOrDefault<string>("DELETE_RULE");
                        var foreignKey = new DatabaseForeignKey
                        {
                            Name = dataRecord.GetValueOrDefault<string>("CONSTRAINT_NAME"),
                            OnDelete = deleteRule == null ? null : ConvertToReferentialAction(deleteRule),
                            Table = table,
                            PrincipalTable = referencedTable
                        };

                        // PAIRED_COLUMNS holds 'column|referenced_column' pairs separated by ','.
                        foreach (var pair in dataRecord.GetValueOrDefault<string>("PAIRED_COLUMNS")!.Split(','))
                        {
                            var names = pair.Split('|');
                            var dependentName = names[0];
                            var principalName = names[1];

                            foreignKey.Columns.Add(table.Columns.Single(y =>
                                string.Equals(y.Name, dependentName, StringComparison.OrdinalIgnoreCase)));
                            foreignKey.PrincipalColumns.Add(referencedTable.Columns.Single(y =>
                                string.Equals(y.Name, principalName, StringComparison.OrdinalIgnoreCase)));
                        }

                        table.ForeignKeys.Add(foreignKey);
                    }
                    catch (Exception ex)
                    {
                        _logger.Logger.LogError(ex, "Error assigning foreign key for {table}.", table.Name);
                    }
                }
            }
        }
    }
}
/// <summary>
/// Drops a column default that merely restates the zero default of the CLR type, so that
/// it is not scaffolded as an explicit default value.
/// </summary>
/// <param name="dataTypeName">The column's data type name (compared case-sensitively).</param>
/// <param name="nullable">Whether the column is nullable; defaults of nullable columns are always kept.</param>
/// <param name="defaultValue">The column default, or <c>null</c> when there is none.</param>
/// <returns>
/// <c>null</c> when there is no default or when a non-nullable numeric column defaults to
/// zero; otherwise <paramref name="defaultValue"/> unchanged.
/// </returns>
private static string? FilterClrDefaults(string dataTypeName, bool nullable, string? defaultValue)
{
    if (defaultValue == null || nullable)
    {
        return defaultValue;
    }

    var isFractionalType = dataTypeName == "decimal"
        || dataTypeName == "double"
        || dataTypeName == "float";

    if (defaultValue == "0")
    {
        // "mediumint" is included so that all integer types are treated alike.
        var isIntegerType = dataTypeName == "bit"
            || dataTypeName == "tinyint"
            || dataTypeName == "smallint"
            || dataTypeName == "mediumint"
            || dataTypeName == "int"
            || dataTypeName == "bigint";

        return isIntegerType || isFractionalType ? null : defaultValue;
    }

    // Zero written with decimals ("0.0", "0.000", ...) only counts for fractional types.
    if (isFractionalType && Regex.IsMatch(defaultValue, @"^0\.0+$"))
    {
        return null;
    }

    return defaultValue;
}
/// <summary>
/// Converts a raw column default into the SQL text used for the scaffolded default value.
/// </summary>
/// <param name="defaultValue">The column default, or <c>null</c> when there is none.</param>
/// <param name="dataType">The column's data type name.</param>
/// <returns>
/// <c>null</c> when there is no default; the default unchanged for a timestamp/datetime
/// column defaulting to CURRENT_TIMESTAMP (with or without a precision argument) and for
/// a bit column whose default starts with <c>b'</c>; otherwise the default as a quoted
/// string literal with backslashes and single quotes escaped.
/// </returns>
private static string? CreateDefaultValueString(string? defaultValue, string dataType)
{
    if (defaultValue == null)
    {
        return null;
    }

    var isTemporal = string.Equals(dataType, "timestamp", StringComparison.OrdinalIgnoreCase)
        || string.Equals(dataType, "datetime", StringComparison.OrdinalIgnoreCase);

    // CURRENT_TIMESTAMP is an expression, not a string, so it must not be quoted. Columns
    // with fractional seconds report it with a precision, e.g. CURRENT_TIMESTAMP(6).
    if (isTemporal
        && Regex.IsMatch(
            defaultValue,
            @"^CURRENT_TIMESTAMP(?:\(\d*\))?$",
            RegexOptions.IgnoreCase | RegexOptions.CultureInvariant))
    {
        return defaultValue;
    }

    // Handle bit values: b'...' literals are passed through as they are.
    if (string.Equals(dataType, "bit", StringComparison.OrdinalIgnoreCase)
        && defaultValue.StartsWith("b'", StringComparison.Ordinal))
    {
        return defaultValue;
    }

    return "'" + defaultValue.Replace(@"\", @"\\").Replace("'", "''") + "'";
}
/// <summary>
/// Renders a value as a national character string literal (<c>N'...'</c>) for use in the
/// generated filter SQL.
/// </summary>
/// <param name="s">The raw value, e.g. a schema or table name.</param>
/// <returns>The literal, with backslashes and single quotes escaped.</returns>
private static string EscapeLiteral(string s)
{
    // Without escaping, a quote in a name would end the literal early and change the query.
    // Backslashes are doubled first, the same way default values are escaped in this file.
    // NOTE(review): this assumes the server treats backslash as an escape character
    // (i.e. NO_BACKSLASH_ESCAPES is not enabled) - confirm.
    var escaped = s.Replace(@"\", @"\\").Replace("'", "''");
    return "N'" + escaped + "'";
}
}
}