modules/platforms/dotnet/Apache.Ignite/Internal/Linq/MethodVisitor.cs (381 lines of code) (raw):
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
namespace Apache.Ignite.Internal.Linq;
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Globalization;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using System.Text.RegularExpressions;
using Common;
using NodaTime;
/// <summary>
/// MethodCall expression visitor. Maps CLR methods to SQL functions.
/// <para />
/// Refer to https://calcite.apache.org/docs/reference.html for supported SQL functions.
/// </summary>
internal static class MethodVisitor
{
private const string TrimBoth = "both";
private const string TrimLeading = "leading";
private const string TrimTrailing = "trailing";
/// <summary> Property visitors. </summary>
private static readonly Dictionary<MemberInfo, string> Properties = new()
{
{typeof(string).GetProperty(nameof(string.Length))!, "length"},
{typeof(LocalDate).GetProperty(nameof(LocalDate.Year))!, "year"},
{typeof(LocalDate).GetProperty(nameof(LocalDate.Month))!, "month"},
{typeof(LocalDate).GetProperty(nameof(LocalDate.Day))!, "dayofmonth"},
{typeof(LocalDate).GetProperty(nameof(LocalDate.DayOfYear))!, "dayofyear"},
{typeof(LocalDate).GetProperty(nameof(LocalDate.DayOfWeek))!, "-1 + dayofweek"},
{typeof(LocalTime).GetProperty(nameof(LocalTime.Hour))!, "hour"},
{typeof(LocalTime).GetProperty(nameof(LocalTime.Minute))!, "minute"},
{typeof(LocalTime).GetProperty(nameof(LocalTime.Second))!, "second"},
{typeof(LocalDateTime).GetProperty(nameof(LocalDateTime.Year))!, "year"},
{typeof(LocalDateTime).GetProperty(nameof(LocalDateTime.Month))!, "month"},
{typeof(LocalDateTime).GetProperty(nameof(LocalDateTime.Day))!, "dayofmonth"},
{typeof(LocalDateTime).GetProperty(nameof(LocalDateTime.DayOfYear))!, "dayofyear"},
{typeof(LocalDateTime).GetProperty(nameof(LocalDateTime.DayOfWeek))!, "-1 + dayofweek"},
{typeof(LocalDateTime).GetProperty(nameof(LocalDateTime.Hour))!, "hour"},
{typeof(LocalDateTime).GetProperty(nameof(LocalDateTime.Minute))!, "minute"},
{typeof(LocalDateTime).GetProperty(nameof(LocalDateTime.Second))!, "second"}
};
/// <summary>
/// Delegates dictionary.
/// </summary>
private static readonly Dictionary<MethodInfo, VisitMethodDelegate> Delegates = new List
<KeyValuePair<MethodInfo?, VisitMethodDelegate>>
{
GetStringMethod(nameof(string.ToLower), Type.EmptyTypes, GetFunc("lower")),
GetStringMethod(nameof(string.ToUpper), Type.EmptyTypes, GetFunc("upper")),
GetStringMethod(nameof(string.Contains), new[] {typeof(string)}, (e, v) => VisitSqlLike(e, v, "'%' || ? || '%'")),
GetStringMethod(nameof(string.StartsWith), new[] {typeof(string)}, (e, v) => VisitSqlLike(e, v, "? || '%'")),
GetStringMethod(nameof(string.EndsWith), new[] {typeof(string)}, (e, v) => VisitSqlLike(e, v, "'%' || ?")),
GetStringMethod(nameof(string.IndexOf), new[] {typeof(string)}, VisitPositionFunc),
GetStringMethod(nameof(string.IndexOf), new[] {typeof(string), typeof(int)}, VisitPositionFunc),
GetStringMethod(nameof(string.Substring), new[] {typeof(int)}, GetFunc("substring", 0, 1)),
GetStringMethod(nameof(string.Substring), new[] {typeof(int), typeof(int)}, GetFunc("substring", inlineConstArgs: true, 0, 1)),
GetStringMethod(nameof(string.Trim), "trim"),
GetStringMethod(nameof(string.TrimStart), "ltrim"),
GetStringMethod(nameof(string.TrimEnd), "rtrim"),
GetParameterizedTrimMethod(nameof(string.Trim), TrimBoth),
GetParameterizedTrimMethod(nameof(string.TrimStart), TrimLeading),
GetParameterizedTrimMethod(nameof(string.TrimEnd), TrimTrailing),
GetCharTrimMethod(nameof(string.Trim), TrimBoth),
GetCharTrimMethod(nameof(string.TrimStart), TrimLeading),
GetCharTrimMethod(nameof(string.TrimEnd), TrimTrailing),
GetStringMethod(nameof(string.Replace), "replace", typeof(string), typeof(string)),
GetStringMethod(nameof(string.Compare), new[] { typeof(string), typeof(string) }, (e, v) => VisitStringCompare(e, v, false)),
GetStringMethod(nameof(string.Compare), new[] { typeof(string), typeof(string), typeof(bool) }, (e, v) => VisitStringCompare(e, v, GetStringCompareIgnoreCaseParameter(e.Arguments[2]))),
GetRegexMethod(nameof(Regex.Replace), "regexp_replace", typeof(string), typeof(string), typeof(string)),
GetRegexMethod(nameof(Regex.Replace), "regexp_replace", typeof(string), typeof(string), typeof(string), typeof(RegexOptions)),
GetRegexMethod(nameof(Regex.IsMatch), "regexp_like", typeof(string), typeof(string)),
GetRegexMethod(nameof(Regex.IsMatch), "regexp_like", typeof(string), typeof(string), typeof(RegexOptions)),
GetMethod(typeof(DateTime), "ToString", new[] {typeof(string)}, (e, v) => VisitFunc(e, v, "formatdatetime", ", 'en', 'UTC'", false)),
GetMathMethod(nameof(Math.Abs), typeof(int)),
GetMathMethod(nameof(Math.Abs), typeof(long)),
GetMathMethod(nameof(Math.Abs), typeof(float)),
GetMathMethod(nameof(Math.Abs), typeof(double)),
GetMathMethod(nameof(Math.Abs), typeof(decimal)),
GetMathMethod(nameof(Math.Abs), typeof(sbyte)),
GetMathMethod(nameof(Math.Abs), typeof(short)),
GetMathMethod(nameof(Math.Acos), typeof(double)),
GetMathMethod(nameof(Math.Acosh), typeof(double)),
GetMathMethod(nameof(Math.Asin), typeof(double)),
GetMathMethod(nameof(Math.Asinh), typeof(double)),
GetMathMethod(nameof(Math.Atan), typeof(double)),
GetMathMethod(nameof(Math.Atanh), typeof(double)),
GetMathMethod(nameof(Math.Atan2), typeof(double), typeof(double)),
GetMathMethod(nameof(Math.Ceiling), typeof(double)),
GetMathMethod(nameof(Math.Ceiling), typeof(decimal)),
GetMathMethod(nameof(Math.Cos), typeof(double)),
GetMathMethod(nameof(Math.Cosh), typeof(double)),
GetMathMethod(nameof(Math.Exp), typeof(double)),
GetMathMethod(nameof(Math.Floor), typeof(double)),
GetMathMethod(nameof(Math.Floor), typeof(decimal)),
GetMathMethod(nameof(Math.Log), "Ln", inlineCostArgs: false, typeof(double)),
GetMathMethod(nameof(Math.Log10), typeof(double)),
GetMathMethod(nameof(Math.Log2), typeof(double)),
GetMathMethod(nameof(Math.Pow), "Power", inlineCostArgs: true, typeof(double), typeof(double)),
GetMathMethod(nameof(Math.Round), typeof(double)),
GetMathMethod(nameof(Math.Round), typeof(double), typeof(int)),
GetMathMethod(nameof(Math.Round), typeof(decimal)),
GetMathMethod(nameof(Math.Round), typeof(decimal), typeof(int)),
GetMathMethod(nameof(Math.Sign), typeof(double)),
GetMathMethod(nameof(Math.Sign), typeof(decimal)),
GetMathMethod(nameof(Math.Sign), typeof(float)),
GetMathMethod(nameof(Math.Sign), typeof(int)),
GetMathMethod(nameof(Math.Sign), typeof(long)),
GetMathMethod(nameof(Math.Sign), typeof(short)),
GetMathMethod(nameof(Math.Sign), typeof(sbyte)),
GetMathMethod(nameof(Math.Sin), typeof(double)),
GetMathMethod(nameof(Math.Sinh), typeof(double)),
GetMathMethod(nameof(Math.Sqrt), typeof(double)),
GetMathMethod(nameof(Math.Tan), typeof(double)),
GetMathMethod(nameof(Math.Tanh), typeof(double)),
GetMathMethod(nameof(Math.Truncate), typeof(double)),
GetMathMethod(nameof(Math.Truncate), typeof(decimal)),
}
.Where(x => x.Key != null)
.ToDictionary(x => x.Key!, x => x.Value);
/// <summary> RegexOptions transformations. </summary>
private static readonly Dictionary<RegexOptions, string> RegexOptionFlags = new()
{
{ RegexOptions.IgnoreCase, "i" },
{ RegexOptions.Multiline, "m" }
};
/// <summary> Method visit delegate. </summary>
[SuppressMessage("Naming", "CA1711:Identifiers should not have incorrect suffix", Justification = "Private.")]
private delegate void VisitMethodDelegate(MethodCallExpression expression, IgniteQueryExpressionVisitor visitor);
/// <summary>
/// Visits a property call expression.
/// </summary>
/// <param name="expression">Expression.</param>
/// <param name="visitor">Visitor.</param>
/// <returns>Success flag.</returns>
public static bool VisitPropertyCall(MemberExpression expression, IgniteQueryExpressionVisitor visitor)
{
if (!Properties.TryGetValue(expression.Member, out var funcName) || expression.Expression == null)
{
return false;
}
visitor.ResultBuilder.Append(funcName).Append('(');
visitor.Visit(expression.Expression);
visitor.ResultBuilder.Append(')');
return true;
}
/// <summary>
/// Visits a method call expression.
/// </summary>
/// <param name="expression">Expression.</param>
/// <param name="visitor">Visitor.</param>
public static void VisitMethodCall(MethodCallExpression expression, IgniteQueryExpressionVisitor visitor)
{
var mtd = expression.Method;
if (!Delegates.TryGetValue(mtd, out var del))
{
throw new NotSupportedException(string.Format(
CultureInfo.InvariantCulture,
"Method not supported: {0}.({1})",
mtd.DeclaringType == null ? "static" : mtd.DeclaringType.FullName,
mtd));
}
del(expression, visitor);
}
/// <summary>
/// Visits a constant call expression.
/// </summary>
/// <param name="expression">Expression.</param>
/// <param name="visitor">Visitor.</param>
/// <returns>Success flag.</returns>
public static bool VisitConstantCall(ConstantExpression expression, IgniteQueryExpressionVisitor visitor)
{
if (expression.Type != typeof(RegexOptions))
{
return false;
}
var regexOptions = expression.Value as RegexOptions? ?? RegexOptions.None;
var result = string.Empty;
foreach (var option in RegexOptionFlags)
{
if (regexOptions.HasFlag(option.Key))
{
result += option.Value;
regexOptions &= ~option.Key;
}
}
if (regexOptions != RegexOptions.None)
{
throw new NotSupportedException($"RegexOptions.{regexOptions} is not supported");
}
// "pos" and "occurence" are required before "matchType".
visitor.ResultBuilder.Append("1, 1, ");
visitor.AppendParameter(result);
return true;
}
/// <summary>
/// Gets the function.
/// </summary>
private static VisitMethodDelegate GetFunc(string func, params int[] adjust) =>
(e, v) => VisitFunc(e, v, func, null, false, adjust);
/// <summary>
/// Gets the function.
/// </summary>
private static VisitMethodDelegate GetFunc(string func, bool inlineConstArgs, params int[] adjust) =>
(e, v) => VisitFunc(e, v, func, null, inlineConstArgs, adjust);
/// <summary>
/// Visits the instance function.
/// </summary>
private static void VisitFunc(
MethodCallExpression expression,
IgniteQueryExpressionVisitor visitor,
string func,
string? suffix,
bool inlineConstArgs,
params int[] adjust)
{
visitor.ResultBuilder.Append(func).Append('(');
var isInstanceMethod = expression.Object != null;
if (isInstanceMethod)
{
visitor.Visit(expression.Object!);
}
for (int i = 0; i < expression.Arguments.Count; i++)
{
var arg = expression.Arguments[i];
if (isInstanceMethod || (i > 0))
{
visitor.ResultBuilder.Append(", ");
}
if (inlineConstArgs &&
arg is ConstantExpression constExpr &&
constExpr.Type.IsPrimitive &&
constExpr.Type != typeof(char))
{
// TODO IGNITE-18258 Remove this logic, we should be able to pass args as SQL params for all functions.
// We only allow inline for numeric types. Other types can lead to SQL injections.
visitor.ResultBuilder.Append(constExpr.Value);
}
else
{
visitor.Visit(arg);
}
AppendAdjustment(visitor, adjust, i + 1);
}
visitor.ResultBuilder.Append(suffix).Append(')');
AppendAdjustment(visitor, adjust, 0);
}
/// <summary>
/// Visits the instance function for Trim specific handling.
/// </summary>
private static void VisitParameterizedTrimFunc(
MethodCallExpression expression,
IgniteQueryExpressionVisitor visitor,
string mode)
{
// trim(leading|trailing|both chars from string)
visitor.ResultBuilder.Append("trim(").Append(mode).Append(' ');
if (expression.Arguments.Count > 0 && expression.Arguments[0] is { } arg)
{
if (arg is ConstantExpression constant)
{
if (constant.Value is char ch)
{
visitor.AppendParameter(ch);
}
else
{
var args = constant.Value as IEnumerable<char>;
if (args == null)
{
throw new NotSupportedException("String.Trim function only supports IEnumerable<char>");
}
var enumeratedArgs = args.ToArray();
if (enumeratedArgs.Length != 1)
{
throw new NotSupportedException("String.Trim function only supports a single argument: " + expression);
}
visitor.AppendParameter(enumeratedArgs[0]);
}
}
else
{
visitor.Visit(arg);
}
}
visitor.ResultBuilder.TrimEnd().Append(" from ");
visitor.Visit(expression.Object!);
visitor.ResultBuilder.Append(')');
}
/// <summary>
/// Visits the function for IndexOf -> POSITION mapping.
/// </summary>
private static void VisitPositionFunc(
MethodCallExpression expression,
IgniteQueryExpressionVisitor visitor)
{
// POSITION(string1 IN string2)
// Returns 1-based index when substring is found, 0 when not found.
visitor.ResultBuilder.Append("-1 + position(");
Debug.Assert(expression.Arguments.Count >= 1, "expression.Arguments.Count >= 1");
visitor.Visit(expression.Arguments[0]);
visitor.ResultBuilder.TrimEnd().Append(" in ");
visitor.Visit(expression.Object!);
if (expression.Arguments.Count > 1)
{
// POSITION(string1 IN string2 FROM integer)
visitor.ResultBuilder.TrimEnd().Append(" from (");
visitor.Visit(expression.Arguments[1]);
visitor.ResultBuilder.Append(" + 1)");
}
visitor.ResultBuilder.TrimEnd().Append(')');
}
/// <summary>
/// Appends the adjustment.
/// </summary>
private static void AppendAdjustment(IgniteQueryExpressionVisitor visitor, int[] adjust, int idx)
{
if (idx < adjust.Length)
{
var delta = adjust[idx];
if (delta > 0)
{
visitor.ResultBuilder.AppendFormat(CultureInfo.InvariantCulture, " + {0}", delta);
}
else if (delta < 0)
{
visitor.ResultBuilder.AppendFormat(CultureInfo.InvariantCulture, " {0}", delta);
}
}
}
/// <summary>
/// Visits the SQL like expression.
/// </summary>
private static void VisitSqlLike(
MethodCallExpression expression,
IgniteQueryExpressionVisitor visitor,
string likeFormat)
{
visitor.ResultBuilder.Append('(');
visitor.Visit(expression.Object!);
visitor.ResultBuilder.AppendFormat(CultureInfo.InvariantCulture, " like {0}) ", likeFormat);
var paramValue = expression.Arguments[0] is ConstantExpression arg
? arg.Value
: ExpressionWalker.EvaluateExpression<object>(expression.Arguments[0]);
visitor.Parameters.Add(paramValue);
}
/// <summary>
/// Get IgnoreCase parameter for string.Compare method.
/// </summary>
private static bool GetStringCompareIgnoreCaseParameter(Expression expression)
{
if (expression is ConstantExpression { Value: bool } constant)
{
return (bool)constant.Value;
}
throw new NotSupportedException(
"Parameter 'ignoreCase' from 'string.Compare method should be specified as a constant expression");
}
/// <summary>
/// Visits string.Compare method.
/// </summary>
private static void VisitStringCompare(MethodCallExpression expression, IgniteQueryExpressionVisitor visitor, bool ignoreCase)
{
// case when (A is not distinct from B) then 0 else (case (A > B) when true then 1 else -1 end) end
var builder = visitor.ResultBuilder;
builder.Append("case when (");
VisitArg(visitor, expression, 0, ignoreCase);
builder.Append(" is not distinct from ");
VisitArg(visitor, expression, 1, ignoreCase);
builder.Append(") then 0 else (case when (");
VisitArg(visitor, expression, 0, ignoreCase);
builder.Append(" > ");
VisitArg(visitor, expression, 1, ignoreCase);
builder.Append(") then 1 else -1 end) end");
}
/// <summary>
/// Visits member expression argument.
/// </summary>
private static void VisitArg(
IgniteQueryExpressionVisitor visitor,
MethodCallExpression expression,
int idx,
bool lower)
{
if (lower)
{
visitor.ResultBuilder.Append("lower(");
}
visitor.Visit(expression.Arguments[idx]);
if (lower)
{
visitor.ResultBuilder.Append(')');
}
}
/// <summary>
/// Gets the method.
/// </summary>
private static KeyValuePair<MethodInfo?, VisitMethodDelegate> GetMethod(
Type type,
string name,
Type[]? argTypes = null,
VisitMethodDelegate? del = null,
bool inlineConstArgs = false)
{
var method = argTypes == null ? type.GetMethod(name) : type.GetMethod(name, argTypes);
return new KeyValuePair<MethodInfo?, VisitMethodDelegate>(method!, del ?? GetFunc(name, inlineConstArgs));
}
/// <summary>
/// Gets the string method.
/// </summary>
private static KeyValuePair<MethodInfo?, VisitMethodDelegate> GetStringMethod(
string name,
Type[]? argTypes = null,
VisitMethodDelegate? del = null)
{
return GetMethod(typeof(string), name, argTypes, del);
}
/// <summary>
/// Gets the string method.
/// </summary>
private static KeyValuePair<MethodInfo?, VisitMethodDelegate> GetStringMethod(
string name,
string sqlName,
params Type[] argTypes)
{
return GetMethod(typeof(string), name, argTypes, GetFunc(sqlName));
}
/// <summary>
/// Gets the Regex method.
/// </summary>
private static KeyValuePair<MethodInfo?, VisitMethodDelegate> GetRegexMethod(
string name,
string sqlName,
params Type[] argTypes)
{
return GetMethod(typeof(Regex), name, argTypes, GetFunc(sqlName));
}
/// <summary>
/// Gets string parameterized Trim(TrimStart, TrimEnd) method.
/// </summary>
private static KeyValuePair<MethodInfo?, VisitMethodDelegate> GetParameterizedTrimMethod(string name, string mode) =>
GetMethod(
typeof(string),
name,
new[] {typeof(char[])},
(e, v) => VisitParameterizedTrimFunc(e, v, mode));
/// <summary>
/// Gets string parameterized Trim(TrimStart, TrimEnd) method that takes a single char.
/// </summary>
private static KeyValuePair<MethodInfo?, VisitMethodDelegate> GetCharTrimMethod(string name, string mode) =>
GetMethod(
typeof(string),
name,
new[] {typeof(char)},
(e, v) => VisitParameterizedTrimFunc(e, v, mode));
/// <summary>
/// Gets the math method.
/// </summary>
private static KeyValuePair<MethodInfo?, VisitMethodDelegate> GetMathMethod(
string name,
string sqlName,
bool inlineCostArgs,
params Type[] argTypes) =>
GetMethod(typeof(Math), name, argTypes, GetFunc(sqlName, inlineCostArgs), inlineCostArgs);
/// <summary>
/// Gets the math method.
/// </summary>
private static KeyValuePair<MethodInfo?, VisitMethodDelegate> GetMathMethod(string name, params Type[] argTypes) =>
GetMathMethod(name, name, false, argTypes);
}