diff --git a/src/Adapter/MSTest.TestAdapter/Extensions/MethodInfoExtensions.cs b/src/Adapter/MSTest.TestAdapter/Extensions/MethodInfoExtensions.cs index f89c971e2a..c8b23bf4f6 100644 --- a/src/Adapter/MSTest.TestAdapter/Extensions/MethodInfoExtensions.cs +++ b/src/Adapter/MSTest.TestAdapter/Extensions/MethodInfoExtensions.cs @@ -197,6 +197,42 @@ internal static void InvokeAsSynchronousTask(this MethodInfo methodInfo, object? } } + private static void InferGenerics(Type parameterType, Type argumentType, List<(Type ParameterType, Type Substitution)> result) + { + if (parameterType.IsGenericMethodParameter()) + { + // We found a generic parameter. The argument type should be the substitution for it. + result.Add((parameterType, argumentType)); + return; + } + + if (!parameterType.ContainsGenericParameters) + { + // We don't have any generics. + return; + } + + if (parameterType.GetElementType() is { } parameterTypeElementType && + argumentType.GetElementType() is { } argumentTypeElementType) + { + // If we have arrays, we need to infer the generic types for the element types. + // For example, if parameterType is `T[]` and argumentType is `string[]`, we need to infer that `T` is `string`. + // So, we call InferGenerics with `T` and `string`. + InferGenerics(parameterTypeElementType, argumentTypeElementType, result); + return; + } + else if (parameterType.GenericTypeArguments.Length == argumentType.GenericTypeArguments.Length) + { + for (int i = 0; i < parameterType.GenericTypeArguments.Length; i++) + { + if (parameterType.GenericTypeArguments[i].ContainsGenericParameters) + { + InferGenerics(parameterType.GenericTypeArguments[i], argumentType.GenericTypeArguments[i], result); + } + } + } + } + // Scenarios to test: // // [DataRow(null, "Hello")] @@ -228,26 +264,30 @@ private static MethodInfo ConstructGenericMethod(MethodInfo methodInfo, object?[ for (int i = 0; i < parameters.Length; i++) { Type parameterType = parameters[i].ParameterType; - if (!parameterType.IsGenericMethodParameter() || arguments[i] is null) + if (!parameterType.ContainsGenericParameters || arguments[i] is null) { continue; } - Type substitution = arguments[i]!/*Very strange nullability warning*/.GetType(); - int mapIndexForParameter = GetMapIndexForParameterType(parameterType, map); - Type? existingSubstitution = map[mapIndexForParameter].Substitution; - - if (existingSubstitution is null || substitution.IsAssignableFrom(existingSubstitution)) - { - map[mapIndexForParameter] = (parameterType, substitution); - } - else if (existingSubstitution.IsAssignableFrom(substitution)) - { - // Do nothing. We already have a good existing substitution. - } - else + var result = new List<(Type ParameterType, Type Substitution)>(); + InferGenerics(parameterType, arguments[i]!/*Very strange nullability warning*/.GetType(), result); + foreach ((Type genericParameterType, Type substitution) in result) { - throw new InvalidOperationException(string.Format(CultureInfo.InvariantCulture, Resource.GenericParameterConflict, parameterType.Name, existingSubstitution, substitution)); + int mapIndexForParameter = GetMapIndexForParameterType(genericParameterType, map); + Type? existingSubstitution = map[mapIndexForParameter].Substitution; + + if (existingSubstitution is null || substitution.IsAssignableFrom(existingSubstitution)) + { + map[mapIndexForParameter] = (genericParameterType, substitution); + } + else if (existingSubstitution.IsAssignableFrom(substitution)) + { + // Do nothing. We already have a good existing substitution. + } + else + { + throw new InvalidOperationException(string.Format(CultureInfo.InvariantCulture, Resource.GenericParameterConflict, parameterType.Name, existingSubstitution, substitution)); + } } } diff --git a/src/Analyzers/MSTest.Analyzers/DataRowShouldBeValidAnalyzer.cs b/src/Analyzers/MSTest.Analyzers/DataRowShouldBeValidAnalyzer.cs index 5c537e169f..a92d269aa8 100644 --- a/src/Analyzers/MSTest.Analyzers/DataRowShouldBeValidAnalyzer.cs +++ b/src/Analyzers/MSTest.Analyzers/DataRowShouldBeValidAnalyzer.cs @@ -1,4 +1,4 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. See LICENSE file in the project root for full license information. using System.Collections.Immutable; @@ -193,10 +193,11 @@ private static void AnalyzeAttribute(SymbolAnalysisContext context, AttributeDat ITypeSymbol? argumentType = constructorArguments[currentArgumentIndex].Type; ITypeSymbol paramType = GetParameterType(methodSymbol.Parameters, currentArgumentIndex, constructorArguments.Length); - if (paramType.TypeKind == TypeKind.TypeParameter) + if (paramType.TypeKind == TypeKind.TypeParameter || + paramType is IArrayTypeSymbol { ElementType.TypeKind: TypeKind.TypeParameter }) { // That means the actual type cannot be determined. We should have issued a separate - // diagnostic for that in GetParameterTypeSubstitutions call above. + // diagnostic for that in AnalyzeGenericMethod call above. continue; } @@ -215,6 +216,44 @@ private static void AnalyzeAttribute(SymbolAnalysisContext context, AttributeDat } } + private static Type GetSystemType(ITypeSymbol type) + { + if (type.TypeKind == TypeKind.Enum) + { + if (((INamedTypeSymbol)type).EnumUnderlyingType is { } underlyingType) + { + type = underlyingType; + } + else + { + // If this is reachable, it will be an error scenario. + return typeof(int); + } + } + + return type.SpecialType switch + { + SpecialType.System_Boolean => typeof(bool), + SpecialType.System_Byte => typeof(byte), + SpecialType.System_Char => typeof(char), + SpecialType.System_Decimal => typeof(decimal), + SpecialType.System_Double => typeof(double), + SpecialType.System_Int16 => typeof(short), + SpecialType.System_Int32 => typeof(int), + SpecialType.System_Int64 => typeof(long), + SpecialType.System_IntPtr => typeof(IntPtr), + SpecialType.System_SByte => typeof(sbyte), + SpecialType.System_Single => typeof(float), + SpecialType.System_String => typeof(string), + SpecialType.System_UInt16 => typeof(ushort), + SpecialType.System_UInt32 => typeof(uint), + SpecialType.System_UInt64 => typeof(ulong), + SpecialType.System_UIntPtr => typeof(UIntPtr), + // All types that can be constants should hopefully be handled above. + _ => throw new ArgumentException($"Unexpected SpecialType '{type.SpecialType}'."), + }; + } + private static void AnalyzeGenericMethod( SymbolAnalysisContext context, SyntaxNode dataRowSyntax, @@ -229,44 +268,51 @@ private static void AnalyzeGenericMethod( var parameterTypesSubstitutions = new Dictionary(SymbolEqualityComparer.Default); foreach (IParameterSymbol parameter in methodSymbol.Parameters) { - ITypeSymbol parameterType = parameter.Type; - if (parameterType.Kind != SymbolKind.TypeParameter) + TypedConstant constructorArgument = constructorArguments[parameter.Ordinal]; + + // This happens for [DataRow(null)] which ends up being resolved + // to DataRow(string?[]? stringArrayData) constructor. + // It also happens with [DataRow((object[]?)null)] which resolves + // to the params object[] constructor + // In this case, the argument is simply "null". + if (constructorArgument.Kind == TypedConstantKind.Array && constructorArgument.IsNull) { continue; } - TypedConstant constructorArgument = constructorArguments[parameter.Ordinal]; if (constructorArgument.Type is null) { // That's an error scenario. The compiler will be complaining about something already. continue; } - // This happens for [DataRow(null)] which ends up being resolved - // to DataRow(string?[]? stringArrayData) constructor. - // It also happens with [DataRow((object[]?)null)] which resolves - // to the params object[] constructor - // In this case, the argument is simply "null". - if (constructorArgument.Kind == TypedConstantKind.Array && constructorArgument.IsNull) + Type? argumentType = constructorArgument.Kind == TypedConstantKind.Array + ? GetSystemType(((IArrayTypeSymbol)constructorArgument.Type).ElementType) + : constructorArgument.Value?.GetType(); + + if (argumentType is null) { continue; } - object? argumentValue = constructorArgument.Value; - if (argumentValue is null) + ITypeSymbol parameterType = constructorArgument.Kind == TypedConstantKind.Array + ? ((IArrayTypeSymbol)parameter.Type).ElementType + : parameter.Type; + + if (parameterType.Kind != SymbolKind.TypeParameter) { continue; } if (parameterTypesSubstitutions.TryGetValue(parameterType, out (ITypeSymbol Symbol, Type SystemType) existingType)) { - if (argumentValue.GetType().IsAssignableTo(existingType.SystemType)) + if (argumentType.IsAssignableTo(existingType.SystemType)) { continue; } - else if (existingType.SystemType.IsAssignableTo(argumentValue.GetType())) + else if (existingType.SystemType.IsAssignableTo(argumentType)) { - parameterTypesSubstitutions[parameterType] = (parameterType, argumentValue.GetType()); + parameterTypesSubstitutions[parameterType] = (parameterType, argumentType); } else { @@ -275,7 +321,7 @@ private static void AnalyzeGenericMethod( } else { - parameterTypesSubstitutions.Add(parameterType, (constructorArgument.Type, argumentValue.GetType())); + parameterTypesSubstitutions.Add(parameterType, (constructorArgument.Type, argumentType)); } } diff --git a/src/Analyzers/MSTest.Analyzers/TestMethodShouldBeValidAnalyzer.cs b/src/Analyzers/MSTest.Analyzers/TestMethodShouldBeValidAnalyzer.cs index b24d1ed204..ebf2240831 100644 --- a/src/Analyzers/MSTest.Analyzers/TestMethodShouldBeValidAnalyzer.cs +++ b/src/Analyzers/MSTest.Analyzers/TestMethodShouldBeValidAnalyzer.cs @@ -54,6 +54,32 @@ public override void Initialize(AnalysisContext context) }); } + private static bool IsOrHasTypeParameter(ITypeSymbol type, ITypeParameterSymbol typeParameter) + { + if (SymbolEqualityComparer.Default.Equals(type, typeParameter)) + { + return true; + } + + if (type is IArrayTypeSymbol array) + { + return IsOrHasTypeParameter(array.ElementType, typeParameter); + } + + if (type is INamedTypeSymbol namedType) + { + foreach (ITypeSymbol typeArgument in namedType.TypeArguments) + { + if (IsOrHasTypeParameter(typeArgument, typeParameter)) + { + return true; + } + } + } + + return false; + } + private static void AnalyzeSymbol(SymbolAnalysisContext context, INamedTypeSymbol testMethodAttributeSymbol, INamedTypeSymbol? taskSymbol, INamedTypeSymbol? valueTaskSymbol, bool canDiscoverInternals) { @@ -74,8 +100,9 @@ private static void AnalyzeSymbol(SymbolAnalysisContext context, INamedTypeSymbo { foreach (ITypeParameterSymbol typeParameter in methodSymbol.TypeParameters) { - // If none of the parameters match the type parameter, then that generic type can't be inferred. - if (!methodSymbol.Parameters.Any(p => typeParameter.Equals(p.Type, SymbolEqualityComparer.Default))) + // If none of the parameters contains the type parameter, then that generic type can't be inferred. + // By "contains", we mean if the type parameter is 'T', we could have 'T', 'T[]', or 'List'. + if (!methodSymbol.Parameters.Any(p => IsOrHasTypeParameter(p.Type, typeParameter))) { context.ReportDiagnostic(methodSymbol.CreateDiagnostic(ValidTestMethodSignatureRule, methodSymbol.Name)); } diff --git a/test/IntegrationTests/MSTest.Acceptance.IntegrationTests/GenericTestMethodTests.cs b/test/IntegrationTests/MSTest.Acceptance.IntegrationTests/GenericTestMethodTests.cs index edf1e91609..974131b734 100644 --- a/test/IntegrationTests/MSTest.Acceptance.IntegrationTests/GenericTestMethodTests.cs +++ b/test/IntegrationTests/MSTest.Acceptance.IntegrationTests/GenericTestMethodTests.cs @@ -65,6 +65,12 @@ at .+? failed ParameterizedMethodSimpleParams \(null,"Hello world"\) \(\d+ms\) Cannot create an instance of T\[] because Type\.ContainsGenericParameters is true\. at .+? + failed ParameterizedMethodWithNestedGeneric \(System\.Collections\.Generic\.List`1\[System.String],System\.Collections\.Generic\.List`1\[System.String]\) \(\d+ms\) + Assert\.Fail failed\. Test method 'ParameterizedMethodWithNestedGeneric' did run with first list \[Hello, World] and second list \[Unit, Testing] + at .+? + failed ParameterizedMethodWithNestedGeneric \(System\.Collections\.Generic\.List`1\[System.Int32],System\.Collections\.Generic\.List`1\[System.Int32]\) \(\d+ms\) + Assert\.Fail failed\. Test method 'ParameterizedMethodWithNestedGeneric' did run with first list \[0, 1] and second list \[2, 3] + at .+? """, RegexOptions.Singleline); } @@ -136,6 +142,24 @@ public void ParameterizedMethodTwoGenericParametersAndFourMethodParameters(params T[] parameter) => Assert.Fail($"Test method 'ParameterizedMethodSimple' did run with parameter '{string.Join(",", parameter)}' and type '{typeof(T)}'."); + + [TestMethod] + [DynamicData(nameof(Data))] + public void ParameterizedMethodWithNestedGeneric(List a, List b) + { + Assert.AreEqual(2, a.Count); + Assert.AreEqual(2, b.Count); + Assert.Fail($"Test method 'ParameterizedMethodWithNestedGeneric' did run with first list [{a[0]}, {a[1]}] and second list [{b[0]}, {b[1]}]"); + } + + public static IEnumerable Data + { + get + { + yield return new object[] { new List() { "Hello", "World" }, new List() { "Unit", "Testing" } }; + yield return new object[] { new List() { 0, 1 }, new List() { 2, 3 } }; + } + } } """; } diff --git a/test/UnitTests/MSTest.Analyzers.UnitTests/DataRowShouldBeValidAnalyzerTests.cs b/test/UnitTests/MSTest.Analyzers.UnitTests/DataRowShouldBeValidAnalyzerTests.cs index 37e2ced325..a355034f92 100644 --- a/test/UnitTests/MSTest.Analyzers.UnitTests/DataRowShouldBeValidAnalyzerTests.cs +++ b/test/UnitTests/MSTest.Analyzers.UnitTests/DataRowShouldBeValidAnalyzerTests.cs @@ -652,4 +652,66 @@ await VerifyCS.VerifyAnalyzerAsync( // /0/Test0.cs(34,6): warning MSTEST0014: The type of the generic parameter 'T' could not be inferred. VerifyCS.Diagnostic(DataRowShouldBeValidAnalyzer.GenericTypeArgumentNotResolvedRule).WithLocation(9).WithArguments("T")); } + + [TestMethod] + public async Task WhenMethodIsGenericWithEnumArgument() + { + string code = """ + using System; + using System.Collections.Generic; + using Microsoft.VisualStudio.TestTools.UnitTesting; + + [TestClass] + public class TestClass + { + [TestMethod] + [DataRow(ConsoleColor.Red)] + public void TestMethod(T t) + { + } + } + """; + + await VerifyCS.VerifyAnalyzerAsync(code); + } + + [TestMethod] + public async Task WhenMethodIsNestedGeneric() + { + string code = """ + using System; + using System.Collections.Generic; + using Microsoft.VisualStudio.TestTools.UnitTesting; + + [TestClass] + public class TestClass + { + [TestMethod] + [DataRow(new int[] { 0, 1 })] + public void TestMethodWithGenericIntArray(T[] p) + { + } + + [TestMethod] + [DataRow(new int[] { })] + public void TestMethodWithGenericIntArrayEmpty(T[] p) + { + } + + [TestMethod] + [DataRow(new ConsoleColor[] { ConsoleColor.Green, ConsoleColor.Red })] + public void TestMethodWithGenericEnumArray(T[] p) + { + } + + [TestMethod] + [DataRow(new ConsoleColor[] { })] + public void TestMethodWithGenericEnumArrayEmpty(T[] p) + { + } + } + """; + + await VerifyCS.VerifyAnalyzerAsync(code); + } } diff --git a/test/UnitTests/MSTest.Analyzers.UnitTests/TestMethodShouldBeValidAnalyzerTests.cs b/test/UnitTests/MSTest.Analyzers.UnitTests/TestMethodShouldBeValidAnalyzerTests.cs index 75ba66adbb..05977b87d7 100644 --- a/test/UnitTests/MSTest.Analyzers.UnitTests/TestMethodShouldBeValidAnalyzerTests.cs +++ b/test/UnitTests/MSTest.Analyzers.UnitTests/TestMethodShouldBeValidAnalyzerTests.cs @@ -159,6 +159,31 @@ public void MyTestMethod() await VerifyCS.VerifyCodeFixAsync(code, fixedCode); } + [TestMethod] + public async Task WhenTestMethodIsGeneric_CanBeInferred_NoDiagnostic() + { + string code = """ + using System.Collections.Generic; + using Microsoft.VisualStudio.TestTools.UnitTesting; + + [TestClass] + public class TestClass + { + [TestMethod] + public void TestMethod1(T[] t) + { + } + + [TestMethod] + public void TestMethod2(List t) + { + } + } + """; + + await VerifyCS.VerifyCodeFixAsync(code, code); + } + [TestMethod] public async Task WhenTestMethodIsGeneric_Diagnostic() {