From 7fcc398c455e396929bef7f73dc9a2365855b789 Mon Sep 17 00:00:00 2001 From: Stef Heyenrath Date: Sat, 7 Dec 2024 17:01:06 +0100 Subject: [PATCH] Add support for SequenceEqual --- .../Parser/ExpressionParser.cs | 41 +++++++++++-------- .../ExpressionParserTests.SequenceEqual.cs | 31 ++++++++++++++ 2 files changed, 54 insertions(+), 18 deletions(-) create mode 100644 test/System.Linq.Dynamic.Core.Tests/Parser/ExpressionParserTests.SequenceEqual.cs diff --git a/src/System.Linq.Dynamic.Core/Parser/ExpressionParser.cs b/src/System.Linq.Dynamic.Core/Parser/ExpressionParser.cs index 4f7a598e..2525a0d6 100644 --- a/src/System.Linq.Dynamic.Core/Parser/ExpressionParser.cs +++ b/src/System.Linq.Dynamic.Core/Parser/ExpressionParser.cs @@ -1822,8 +1822,7 @@ private Expression ParseMemberAccess(Type? type, Expression? expression, string? if (isApplicableForEnumerable && TypeHelper.TryFindGenericType(typeof(IEnumerable<>), type, out var enumerableType)) { - var elementType = enumerableType.GetTypeInfo().GetGenericTypeArguments()[0]; - if (TryParseEnumerable(expression!, elementType, id, errorPos, type, out args, out var enumerableExpression)) + if (TryParseEnumerable(expression!, enumerableType, id, errorPos, type, out args, out var enumerableExpression)) { return enumerableExpression; } @@ -2061,8 +2060,10 @@ private Expression ParseAsEnumOrNestedClass(string id) return ParseMemberAccess(type, null, identifier); } - private bool TryParseEnumerable(Expression instance, Type elementType, string methodName, int errorPos, Type? type, out Expression[]? args, [NotNullWhen(true)] out Expression? expression) + private bool TryParseEnumerable(Expression instance, Type enumerableType, string methodName, int errorPos, Type? type, out Expression[]? args, [NotNullWhen(true)] out Expression? expression) { + var elementType = enumerableType.GetTypeInfo().GetGenericTypeArguments()[0]; + // Keep the current _parent. var oldParent = _parent; @@ -2124,7 +2125,7 @@ private bool TryParseEnumerable(Expression instance, Type elementType, string me // #633 - For Average without any arguments, try to find the non-generic Average method on the callType for the supplied parameter type. if (methodName == nameof(Enumerable.Average) && args.Length == 0 && _methodFinder.TryFindAverageMethod(callType, theType, out var averageMethod)) { - expression = Expression.Call(null, averageMethod, new[] { instance }); + expression = Expression.Call(null, averageMethod, [instance]); return true; } @@ -2136,56 +2137,60 @@ private bool TryParseEnumerable(Expression instance, Type elementType, string me throw ParseError(_textParser.CurrentToken.Pos, Res.FunctionRequiresOneArg, methodName); } - typeArgs = new[] { ResolveTypeFromArgumentExpression(methodName, args[0]) }; - args = new Expression[0]; + typeArgs = [ResolveTypeFromArgumentExpression(methodName, args[0])]; + args = []; } else if (new[] { "Max", "Min", "Select", "OrderBy", "OrderByDescending", "ThenBy", "ThenByDescending", "GroupBy" }.Contains(methodName)) { if (args.Length == 2) { - typeArgs = new[] { elementType, args[0].Type, args[1].Type }; + typeArgs = [elementType, args[0].Type, args[1].Type]; } else if (args.Length == 1) { - typeArgs = new[] { elementType, args[0].Type }; + typeArgs = [elementType, args[0].Type]; } else { - typeArgs = new[] { elementType }; + typeArgs = [elementType]; } } else if (methodName == "SelectMany") { var bodyType = Expression.Lambda(args[0], innerIt).Body.Type; - var interfaces = bodyType.GetInterfaces().Union(new[] { bodyType }); - Type interfaceType = interfaces.Single(i => i.Name == typeof(IEnumerable<>).Name); - Type resultType = interfaceType.GetTypeInfo().GetGenericTypeArguments()[0]; - typeArgs = new[] { elementType, resultType }; + var interfaces = bodyType.GetInterfaces().Union([bodyType]); + var interfaceType = interfaces.Single(i => i.Name == typeof(IEnumerable<>).Name); + var resultType = interfaceType.GetTypeInfo().GetGenericTypeArguments()[0]; + typeArgs = [elementType, resultType]; } else { - typeArgs = new[] { elementType }; + typeArgs = [elementType]; } if (args.Length == 0) { - args = new[] { instance }; + args = [instance]; } else { if (new[] { "Concat", "Contains", "ContainsKey", "DefaultIfEmpty", "Except", "Intersect", "Skip", "Take", "Union" }.Contains(methodName)) { - args = new[] { instance, args[0] }; + args = [instance, args[0]]; + } + else if (methodName == nameof(Enumerable.SequenceEqual)) + { + args = [instance, args[0]]; } else { if (args.Length == 2) { - args = new[] { instance, Expression.Lambda(args[0], innerIt), Expression.Lambda(args[1], innerIt) }; + args = [instance, Expression.Lambda(args[0], innerIt), Expression.Lambda(args[1], innerIt)]; } else { - args = new[] { instance, Expression.Lambda(args[0], innerIt) }; + args = [instance, Expression.Lambda(args[0], innerIt)]; } } } diff --git a/test/System.Linq.Dynamic.Core.Tests/Parser/ExpressionParserTests.SequenceEqual.cs b/test/System.Linq.Dynamic.Core.Tests/Parser/ExpressionParserTests.SequenceEqual.cs new file mode 100644 index 00000000..bb014bfc --- /dev/null +++ b/test/System.Linq.Dynamic.Core.Tests/Parser/ExpressionParserTests.SequenceEqual.cs @@ -0,0 +1,31 @@ +using System.Linq.Dynamic.Core.Parser; +using System.Linq.Expressions; +using Xunit; + +namespace System.Linq.Dynamic.Core.Tests.Parser; + +partial class ExpressionParserTests +{ + [Fact] + public void Parse_SequenceEqual() + { + // Arrange + var parameter = Expression.Parameter(typeof(Entity), "Entity"); + + var parser = new ExpressionParser( + [parameter], + "Entity.ArrayA.SequenceEqual(Entity.ArrayB)", + null, + null); + + // Act + parser.Parse(typeof(bool)); + } + + public class Entity + { + public string[] ArrayA { get; set; } = []; + + public string[] ArrayB { get; set; } = []; + } +} \ No newline at end of file