Skip to content

Commit

Permalink
Improve generic support (#4739)
Browse files Browse the repository at this point in the history
  • Loading branch information
Youssef1313 authored Jan 22, 2025
1 parent 7ab0087 commit afa7b18
Show file tree
Hide file tree
Showing 6 changed files with 259 additions and 35 deletions.
70 changes: 55 additions & 15 deletions src/Adapter/MSTest.TestAdapter/Extensions/MethodInfoExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,42 @@ internal static void InvokeAsSynchronousTask(this MethodInfo methodInfo, object?
}
}

private static void InferGenerics(Type parameterType, Type argumentType, List<(Type ParameterType, Type Substitution)> result)
{
if (parameterType.IsGenericMethodParameter())
{
// We found a generic parameter. The argument type should be the substitution for it.
result.Add((parameterType, argumentType));
return;
}

if (!parameterType.ContainsGenericParameters)
{
// We don't have any generics.
return;
}

if (parameterType.GetElementType() is { } parameterTypeElementType &&
argumentType.GetElementType() is { } argumentTypeElementType)
{
// If we have arrays, we need to infer the generic types for the element types.
// For example, if parameterType is `T[]` and argumentType is `string[]`, we need to infer that `T` is `string`.
// So, we call InferGenerics with `T` and `string`.
InferGenerics(parameterTypeElementType, argumentTypeElementType, result);
return;
}
else if (parameterType.GenericTypeArguments.Length == argumentType.GenericTypeArguments.Length)
{
for (int i = 0; i < parameterType.GenericTypeArguments.Length; i++)
{
if (parameterType.GenericTypeArguments[i].ContainsGenericParameters)
{
InferGenerics(parameterType.GenericTypeArguments[i], argumentType.GenericTypeArguments[i], result);
}
}
}
}

// Scenarios to test:
//
// [DataRow(null, "Hello")]
Expand Down Expand Up @@ -228,26 +264,30 @@ private static MethodInfo ConstructGenericMethod(MethodInfo methodInfo, object?[
for (int i = 0; i < parameters.Length; i++)
{
Type parameterType = parameters[i].ParameterType;
if (!parameterType.IsGenericMethodParameter() || arguments[i] is null)
if (!parameterType.ContainsGenericParameters || arguments[i] is null)
{
continue;
}

Type substitution = arguments[i]!/*Very strange nullability warning*/.GetType();
int mapIndexForParameter = GetMapIndexForParameterType(parameterType, map);
Type? existingSubstitution = map[mapIndexForParameter].Substitution;

if (existingSubstitution is null || substitution.IsAssignableFrom(existingSubstitution))
{
map[mapIndexForParameter] = (parameterType, substitution);
}
else if (existingSubstitution.IsAssignableFrom(substitution))
{
// Do nothing. We already have a good existing substitution.
}
else
var result = new List<(Type ParameterType, Type Substitution)>();
InferGenerics(parameterType, arguments[i]!/*Very strange nullability warning*/.GetType(), result);
foreach ((Type genericParameterType, Type substitution) in result)
{
throw new InvalidOperationException(string.Format(CultureInfo.InvariantCulture, Resource.GenericParameterConflict, parameterType.Name, existingSubstitution, substitution));
int mapIndexForParameter = GetMapIndexForParameterType(genericParameterType, map);
Type? existingSubstitution = map[mapIndexForParameter].Substitution;

if (existingSubstitution is null || substitution.IsAssignableFrom(existingSubstitution))
{
map[mapIndexForParameter] = (genericParameterType, substitution);
}
else if (existingSubstitution.IsAssignableFrom(substitution))
{
// Do nothing. We already have a good existing substitution.
}
else
{
throw new InvalidOperationException(string.Format(CultureInfo.InvariantCulture, Resource.GenericParameterConflict, parameterType.Name, existingSubstitution, substitution));
}
}
}

Expand Down
82 changes: 64 additions & 18 deletions src/Analyzers/MSTest.Analyzers/DataRowShouldBeValidAnalyzer.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System.Collections.Immutable;
Expand Down Expand Up @@ -193,10 +193,11 @@ private static void AnalyzeAttribute(SymbolAnalysisContext context, AttributeDat

ITypeSymbol? argumentType = constructorArguments[currentArgumentIndex].Type;
ITypeSymbol paramType = GetParameterType(methodSymbol.Parameters, currentArgumentIndex, constructorArguments.Length);
if (paramType.TypeKind == TypeKind.TypeParameter)
if (paramType.TypeKind == TypeKind.TypeParameter ||
paramType is IArrayTypeSymbol { ElementType.TypeKind: TypeKind.TypeParameter })
{
// That means the actual type cannot be determined. We should have issued a separate
// diagnostic for that in GetParameterTypeSubstitutions call above.
// diagnostic for that in AnalyzeGenericMethod call above.
continue;
}

Expand All @@ -215,6 +216,44 @@ private static void AnalyzeAttribute(SymbolAnalysisContext context, AttributeDat
}
}

private static Type GetSystemType(ITypeSymbol type)
{
if (type.TypeKind == TypeKind.Enum)
{
if (((INamedTypeSymbol)type).EnumUnderlyingType is { } underlyingType)
{
type = underlyingType;
}
else
{
// If this is reachable, it will be an error scenario.
return typeof(int);
}
}

return type.SpecialType switch
{
SpecialType.System_Boolean => typeof(bool),
SpecialType.System_Byte => typeof(byte),
SpecialType.System_Char => typeof(char),
SpecialType.System_Decimal => typeof(decimal),
SpecialType.System_Double => typeof(double),
SpecialType.System_Int16 => typeof(short),
SpecialType.System_Int32 => typeof(int),
SpecialType.System_Int64 => typeof(long),
SpecialType.System_IntPtr => typeof(IntPtr),
SpecialType.System_SByte => typeof(sbyte),
SpecialType.System_Single => typeof(float),
SpecialType.System_String => typeof(string),
SpecialType.System_UInt16 => typeof(ushort),
SpecialType.System_UInt32 => typeof(uint),
SpecialType.System_UInt64 => typeof(ulong),
SpecialType.System_UIntPtr => typeof(UIntPtr),
// All types that can be constants should hopefully be handled above.
_ => throw new ArgumentException($"Unexpected SpecialType '{type.SpecialType}'."),
};
}

private static void AnalyzeGenericMethod(
SymbolAnalysisContext context,
SyntaxNode dataRowSyntax,
Expand All @@ -229,44 +268,51 @@ private static void AnalyzeGenericMethod(
var parameterTypesSubstitutions = new Dictionary<ITypeSymbol, (ITypeSymbol Symbol, Type SystemType)>(SymbolEqualityComparer.Default);
foreach (IParameterSymbol parameter in methodSymbol.Parameters)
{
ITypeSymbol parameterType = parameter.Type;
if (parameterType.Kind != SymbolKind.TypeParameter)
TypedConstant constructorArgument = constructorArguments[parameter.Ordinal];

// This happens for [DataRow(null)] which ends up being resolved
// to DataRow(string?[]? stringArrayData) constructor.
// It also happens with [DataRow((object[]?)null)] which resolves
// to the params object[] constructor
// In this case, the argument is simply "null".
if (constructorArgument.Kind == TypedConstantKind.Array && constructorArgument.IsNull)
{
continue;
}

TypedConstant constructorArgument = constructorArguments[parameter.Ordinal];
if (constructorArgument.Type is null)
{
// That's an error scenario. The compiler will be complaining about something already.
continue;
}

// This happens for [DataRow(null)] which ends up being resolved
// to DataRow(string?[]? stringArrayData) constructor.
// It also happens with [DataRow((object[]?)null)] which resolves
// to the params object[] constructor
// In this case, the argument is simply "null".
if (constructorArgument.Kind == TypedConstantKind.Array && constructorArgument.IsNull)
Type? argumentType = constructorArgument.Kind == TypedConstantKind.Array
? GetSystemType(((IArrayTypeSymbol)constructorArgument.Type).ElementType)
: constructorArgument.Value?.GetType();

if (argumentType is null)
{
continue;
}

object? argumentValue = constructorArgument.Value;
if (argumentValue is null)
ITypeSymbol parameterType = constructorArgument.Kind == TypedConstantKind.Array
? ((IArrayTypeSymbol)parameter.Type).ElementType
: parameter.Type;

if (parameterType.Kind != SymbolKind.TypeParameter)
{
continue;
}

if (parameterTypesSubstitutions.TryGetValue(parameterType, out (ITypeSymbol Symbol, Type SystemType) existingType))
{
if (argumentValue.GetType().IsAssignableTo(existingType.SystemType))
if (argumentType.IsAssignableTo(existingType.SystemType))
{
continue;
}
else if (existingType.SystemType.IsAssignableTo(argumentValue.GetType()))
else if (existingType.SystemType.IsAssignableTo(argumentType))
{
parameterTypesSubstitutions[parameterType] = (parameterType, argumentValue.GetType());
parameterTypesSubstitutions[parameterType] = (parameterType, argumentType);
}
else
{
Expand All @@ -275,7 +321,7 @@ private static void AnalyzeGenericMethod(
}
else
{
parameterTypesSubstitutions.Add(parameterType, (constructorArgument.Type, argumentValue.GetType()));
parameterTypesSubstitutions.Add(parameterType, (constructorArgument.Type, argumentType));
}
}

Expand Down
31 changes: 29 additions & 2 deletions src/Analyzers/MSTest.Analyzers/TestMethodShouldBeValidAnalyzer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,32 @@ public override void Initialize(AnalysisContext context)
});
}

private static bool IsOrHasTypeParameter(ITypeSymbol type, ITypeParameterSymbol typeParameter)
{
if (SymbolEqualityComparer.Default.Equals(type, typeParameter))
{
return true;
}

if (type is IArrayTypeSymbol array)
{
return IsOrHasTypeParameter(array.ElementType, typeParameter);
}

if (type is INamedTypeSymbol namedType)
{
foreach (ITypeSymbol typeArgument in namedType.TypeArguments)
{
if (IsOrHasTypeParameter(typeArgument, typeParameter))
{
return true;
}
}
}

return false;
}

private static void AnalyzeSymbol(SymbolAnalysisContext context, INamedTypeSymbol testMethodAttributeSymbol, INamedTypeSymbol? taskSymbol,
INamedTypeSymbol? valueTaskSymbol, bool canDiscoverInternals)
{
Expand All @@ -74,8 +100,9 @@ private static void AnalyzeSymbol(SymbolAnalysisContext context, INamedTypeSymbo
{
foreach (ITypeParameterSymbol typeParameter in methodSymbol.TypeParameters)
{
// If none of the parameters match the type parameter, then that generic type can't be inferred.
if (!methodSymbol.Parameters.Any(p => typeParameter.Equals(p.Type, SymbolEqualityComparer.Default)))
// If none of the parameters contains the type parameter, then that generic type can't be inferred.
// By "contains", we mean if the type parameter is 'T', we could have 'T', 'T[]', or 'List<T>'.
if (!methodSymbol.Parameters.Any(p => IsOrHasTypeParameter(p.Type, typeParameter)))
{
context.ReportDiagnostic(methodSymbol.CreateDiagnostic(ValidTestMethodSignatureRule, methodSymbol.Name));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,12 @@ at .+?
failed ParameterizedMethodSimpleParams \(null,"Hello world"\) \(\d+ms\)
Cannot create an instance of T\[] because Type\.ContainsGenericParameters is true\.
at .+?
failed ParameterizedMethodWithNestedGeneric \(System\.Collections\.Generic\.List`1\[System.String],System\.Collections\.Generic\.List`1\[System.String]\) \(\d+ms\)
Assert\.Fail failed\. Test method 'ParameterizedMethodWithNestedGeneric' did run with first list \[Hello, World] and second list \[Unit, Testing]
at .+?
failed ParameterizedMethodWithNestedGeneric \(System\.Collections\.Generic\.List`1\[System.Int32],System\.Collections\.Generic\.List`1\[System.Int32]\) \(\d+ms\)
Assert\.Fail failed\. Test method 'ParameterizedMethodWithNestedGeneric' did run with first list \[0, 1] and second list \[2, 3]
at .+?
""", RegexOptions.Singleline);
}

Expand Down Expand Up @@ -136,6 +142,24 @@ public void ParameterizedMethodTwoGenericParametersAndFourMethodParameters<T1, T
[DataRow(null, "Hello world")]
public void ParameterizedMethodSimpleParams<T>(params T[] parameter)
=> Assert.Fail($"Test method 'ParameterizedMethodSimple' did run with parameter '{string.Join(",", parameter)}' and type '{typeof(T)}'.");
[TestMethod]
[DynamicData(nameof(Data))]
public void ParameterizedMethodWithNestedGeneric<T>(List<T> a, List<T> b)
{
Assert.AreEqual(2, a.Count);
Assert.AreEqual(2, b.Count);
Assert.Fail($"Test method 'ParameterizedMethodWithNestedGeneric' did run with first list [{a[0]}, {a[1]}] and second list [{b[0]}, {b[1]}]");
}
public static IEnumerable<object[]> Data
{
get
{
yield return new object[] { new List<string>() { "Hello", "World" }, new List<string>() { "Unit", "Testing" } };
yield return new object[] { new List<int>() { 0, 1 }, new List<int>() { 2, 3 } };
}
}
}
""";
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -652,4 +652,66 @@ await VerifyCS.VerifyAnalyzerAsync(
// /0/Test0.cs(34,6): warning MSTEST0014: The type of the generic parameter 'T' could not be inferred.
VerifyCS.Diagnostic(DataRowShouldBeValidAnalyzer.GenericTypeArgumentNotResolvedRule).WithLocation(9).WithArguments("T"));
}

[TestMethod]
public async Task WhenMethodIsGenericWithEnumArgument()
{
string code = """
using System;
using System.Collections.Generic;
using Microsoft.VisualStudio.TestTools.UnitTesting;
[TestClass]
public class TestClass
{
[TestMethod]
[DataRow(ConsoleColor.Red)]
public void TestMethod<T>(T t)
{
}
}
""";

await VerifyCS.VerifyAnalyzerAsync(code);
}

[TestMethod]
public async Task WhenMethodIsNestedGeneric()
{
string code = """
using System;
using System.Collections.Generic;
using Microsoft.VisualStudio.TestTools.UnitTesting;
[TestClass]
public class TestClass
{
[TestMethod]
[DataRow(new int[] { 0, 1 })]
public void TestMethodWithGenericIntArray<T>(T[] p)
{
}
[TestMethod]
[DataRow(new int[] { })]
public void TestMethodWithGenericIntArrayEmpty<T>(T[] p)
{
}
[TestMethod]
[DataRow(new ConsoleColor[] { ConsoleColor.Green, ConsoleColor.Red })]
public void TestMethodWithGenericEnumArray<T>(T[] p)
{
}
[TestMethod]
[DataRow(new ConsoleColor[] { })]
public void TestMethodWithGenericEnumArrayEmpty<T>(T[] p)
{
}
}
""";

await VerifyCS.VerifyAnalyzerAsync(code);
}
}
Loading

0 comments on commit afa7b18

Please sign in to comment.