From 4d258a28aba054ea18785d36b4bdd83da023aefb Mon Sep 17 00:00:00 2001 From: Marco Antonio Regueira Date: Sun, 11 Aug 2024 06:33:47 +0200 Subject: [PATCH] Feature: Enable call forwarding and substitution for non virtual methods or sealed classes implementing an interface. (#700) How to use: var substitute = Substitute.ForTypeForwardingTo (argsList); In this case, it doesn't matter if methods are virtual or not; it will intercept all calls since we will be working with an interface all the time. For Limitations: Overriding virtual methods effectively replaces its implementation both for internal and external calls. With this implementation NSubstitute will only intercept calls made by client classes using the interface. Calls made from inside the object itself to its own method, will hit the actual implementation. --- src/NSubstitute/Core/IProxyFactory.cs | 2 +- src/NSubstitute/Core/SubstituteFactory.cs | 8 +- .../Exceptions/TypeForwardingException.cs | 21 ++++ .../CastleDynamicProxyFactory.cs | 62 ++++++++++- .../CastleInvocationMapper.cs | 3 +- .../DelegateProxy/DelegateProxyFactory.cs | 4 +- src/NSubstitute/Proxies/ProxyFactory.cs | 6 +- src/NSubstitute/Substitute.cs | 20 ++++ ...ngForConcreteTypesAndMultipleInterfaces.cs | 30 +++++ .../TypeForwarding.cs | 104 ++++++++++++++++++ 10 files changed, 242 insertions(+), 18 deletions(-) create mode 100644 src/NSubstitute/Exceptions/TypeForwardingException.cs create mode 100644 tests/NSubstitute.Acceptance.Specs/TypeForwarding.cs diff --git a/src/NSubstitute/Core/IProxyFactory.cs b/src/NSubstitute/Core/IProxyFactory.cs index 541aad90..31cd3ed8 100644 --- a/src/NSubstitute/Core/IProxyFactory.cs +++ b/src/NSubstitute/Core/IProxyFactory.cs @@ -2,5 +2,5 @@ namespace NSubstitute.Core; public interface IProxyFactory { - object GenerateProxy(ICallRouter callRouter, Type typeToProxy, Type[]? additionalInterfaces, object?[]? constructorArguments); + object GenerateProxy(ICallRouter callRouter, Type typeToProxy, Type[]? additionalInterfaces, bool isPartial, object?[]? constructorArguments); } \ No newline at end of file diff --git a/src/NSubstitute/Core/SubstituteFactory.cs b/src/NSubstitute/Core/SubstituteFactory.cs index 76a9f652..e55c2ffd 100644 --- a/src/NSubstitute/Core/SubstituteFactory.cs +++ b/src/NSubstitute/Core/SubstituteFactory.cs @@ -14,7 +14,7 @@ public class SubstituteFactory(ISubstituteStateFactory substituteStateFactory, I /// public object Create(Type[] typesToProxy, object?[] constructorArguments) { - return Create(typesToProxy, constructorArguments, callBaseByDefault: false); + return Create(typesToProxy, constructorArguments, callBaseByDefault: false, isPartial: false); } /// @@ -33,10 +33,10 @@ public object CreatePartial(Type[] typesToProxy, object?[] constructorArguments) throw new CanNotPartiallySubForInterfaceOrDelegateException(primaryProxyType); } - return Create(typesToProxy, constructorArguments, callBaseByDefault: true); + return Create(typesToProxy, constructorArguments, callBaseByDefault: true, isPartial: true); } - private object Create(Type[] typesToProxy, object?[] constructorArguments, bool callBaseByDefault) + private object Create(Type[] typesToProxy, object?[] constructorArguments, bool callBaseByDefault, bool isPartial) { var substituteState = substituteStateFactory.Create(this); substituteState.CallBaseConfiguration.CallBaseByDefault = callBaseByDefault; @@ -46,7 +46,7 @@ private object Create(Type[] typesToProxy, object?[] constructorArguments, bool var callRouter = callRouterFactory.Create(substituteState, canConfigureBaseCalls); var additionalTypes = typesToProxy.Where(x => x != primaryProxyType).ToArray(); - var proxy = proxyFactory.GenerateProxy(callRouter, primaryProxyType, additionalTypes, constructorArguments); + var proxy = proxyFactory.GenerateProxy(callRouter, primaryProxyType, additionalTypes, isPartial, constructorArguments); return proxy; } diff --git a/src/NSubstitute/Exceptions/TypeForwardingException.cs b/src/NSubstitute/Exceptions/TypeForwardingException.cs new file mode 100644 index 00000000..5ddf2703 --- /dev/null +++ b/src/NSubstitute/Exceptions/TypeForwardingException.cs @@ -0,0 +1,21 @@ +namespace NSubstitute.Exceptions; + +public abstract class TypeForwardingException(string message) : SubstituteException(message) +{ +} + +public sealed class CanNotForwardCallsToClassNotImplementingInterfaceException(Type type) : TypeForwardingException(DescribeProblem(type)) +{ + private static string DescribeProblem(Type type) + { + return string.Format("The provided class '{0}' doesn't implement all requested interfaces. ", type.Name); + } +} + +public sealed class CanNotForwardCallsToAbstractClassException(Type type) : TypeForwardingException(DescribeProblem(type)) +{ + private static string DescribeProblem(Type type) + { + return string.Format("The provided class '{0}' is abstract. ", type.Name); + } +} diff --git a/src/NSubstitute/Proxies/CastleDynamicProxy/CastleDynamicProxyFactory.cs b/src/NSubstitute/Proxies/CastleDynamicProxy/CastleDynamicProxyFactory.cs index e1c0d1ef..a445f42f 100644 --- a/src/NSubstitute/Proxies/CastleDynamicProxy/CastleDynamicProxyFactory.cs +++ b/src/NSubstitute/Proxies/CastleDynamicProxy/CastleDynamicProxyFactory.cs @@ -10,14 +10,14 @@ public class CastleDynamicProxyFactory(ICallFactory callFactory, IArgumentSpecif private readonly ProxyGenerator _proxyGenerator = new ProxyGenerator(); private readonly AllMethodsExceptCallRouterCallsHook _allMethodsExceptCallRouterCallsHook = new AllMethodsExceptCallRouterCallsHook(); - public object GenerateProxy(ICallRouter callRouter, Type typeToProxy, Type[]? additionalInterfaces, object?[]? constructorArguments) + public object GenerateProxy(ICallRouter callRouter, Type typeToProxy, Type[]? additionalInterfaces, bool isPartial, object?[]? constructorArguments) { return typeToProxy.IsDelegate() ? GenerateDelegateProxy(callRouter, typeToProxy, additionalInterfaces, constructorArguments) - : GenerateTypeProxy(callRouter, typeToProxy, additionalInterfaces, constructorArguments); + : GenerateTypeProxy(callRouter, typeToProxy, additionalInterfaces, isPartial, constructorArguments); } - private object GenerateTypeProxy(ICallRouter callRouter, Type typeToProxy, Type[]? additionalInterfaces, object?[]? constructorArguments) + private object GenerateTypeProxy(ICallRouter callRouter, Type typeToProxy, Type[]? additionalInterfaces, bool isPartial, object?[]? constructorArguments) { VerifyClassHasNotBeenPassedAsAnAdditionalInterface(additionalInterfaces); @@ -31,7 +31,8 @@ private object GenerateTypeProxy(ICallRouter callRouter, Type typeToProxy, Type[ additionalInterfaces, constructorArguments, [proxyIdInterceptor, forwardingInterceptor], - proxyGenerationOptions); + proxyGenerationOptions, + isPartial); forwardingInterceptor.SwitchToFullDispatchMode(); return proxy; @@ -54,7 +55,8 @@ private object GenerateDelegateProxy(ICallRouter callRouter, Type delegateType, additionalInterfaces: null, constructorArguments: null, interceptors: [proxyIdInterceptor, forwardingInterceptor], - proxyGenerationOptions); + proxyGenerationOptions, + isPartial: false); forwardingInterceptor.SwitchToFullDispatchMode(); @@ -75,8 +77,13 @@ private CastleForwardingInterceptor CreateForwardingInterceptor(ICallRouter call private object CreateProxyUsingCastleProxyGenerator(Type typeToProxy, Type[]? additionalInterfaces, object?[]? constructorArguments, IInterceptor[] interceptors, - ProxyGenerationOptions proxyGenerationOptions) + ProxyGenerationOptions proxyGenerationOptions, + bool isPartial) { + if (isPartial) + return CreatePartialProxy(typeToProxy, additionalInterfaces, constructorArguments, interceptors, proxyGenerationOptions, isPartial); + + if (typeToProxy.GetTypeInfo().IsInterface) { VerifyNoConstructorArgumentsGivenForInterface(constructorArguments); @@ -96,6 +103,7 @@ private object CreateProxyUsingCastleProxyGenerator(Type typeToProxy, Type[]? ad additionalInterfaces = interfaces; } + return _proxyGenerator.CreateClassProxy(typeToProxy, additionalInterfaces, proxyGenerationOptions, @@ -103,6 +111,32 @@ private object CreateProxyUsingCastleProxyGenerator(Type typeToProxy, Type[]? ad interceptors); } + private object CreatePartialProxy(Type typeToProxy, Type[]? additionalInterfaces, object?[]? constructorArguments, IInterceptor[] interceptors, ProxyGenerationOptions proxyGenerationOptions, bool isPartial) + { + if (typeToProxy.GetTypeInfo().IsClass && + additionalInterfaces != null && + additionalInterfaces.Any()) + { + VerifyClassIsNotAbstract(typeToProxy); + VerifyClassImplementsAllInterfaces(typeToProxy, additionalInterfaces); + + var targetObject = Activator.CreateInstance(typeToProxy, constructorArguments); + typeToProxy = additionalInterfaces.First(); + + return _proxyGenerator.CreateInterfaceProxyWithTarget(typeToProxy, + additionalInterfaces, + target: targetObject, + options: proxyGenerationOptions, + interceptors: interceptors); + } + + return _proxyGenerator.CreateClassProxy(typeToProxy, + additionalInterfaces, + proxyGenerationOptions, + constructorArguments, + interceptors); + } + private ProxyGenerationOptions GetOptionsToMixinCallRouterProvider(ICallRouter callRouter) { var options = new ProxyGenerationOptions(_allMethodsExceptCallRouterCallsHook); @@ -116,6 +150,22 @@ private ProxyGenerationOptions GetOptionsToMixinCallRouterProvider(ICallRouter c return options; } + private static void VerifyClassImplementsAllInterfaces(Type classType, IEnumerable additionalInterfaces) + { + if (!additionalInterfaces.All(x => x.GetTypeInfo().IsAssignableFrom(classType.GetTypeInfo()))) + { + throw new CanNotForwardCallsToClassNotImplementingInterfaceException(classType); + } + } + + private static void VerifyClassIsNotAbstract(Type classType) + { + if (classType.GetTypeInfo().IsAbstract) + { + throw new CanNotForwardCallsToAbstractClassException(classType); + } + } + private static void VerifyNoConstructorArgumentsGivenForInterface(object?[]? constructorArguments) { if (HasItems(constructorArguments)) diff --git a/src/NSubstitute/Proxies/CastleDynamicProxy/CastleInvocationMapper.cs b/src/NSubstitute/Proxies/CastleDynamicProxy/CastleInvocationMapper.cs index 7ebac3e9..ddc15405 100644 --- a/src/NSubstitute/Proxies/CastleDynamicProxy/CastleInvocationMapper.cs +++ b/src/NSubstitute/Proxies/CastleDynamicProxy/CastleInvocationMapper.cs @@ -10,8 +10,7 @@ public virtual ICall Map(IInvocation castleInvocation) Func? baseMethod = null; if (castleInvocation.InvocationTarget != null && castleInvocation.MethodInvocationTarget.IsVirtual && - !castleInvocation.MethodInvocationTarget.IsAbstract && - !castleInvocation.MethodInvocationTarget.IsFinal) + !castleInvocation.MethodInvocationTarget.IsAbstract) { baseMethod = CreateBaseResultInvocation(castleInvocation); } diff --git a/src/NSubstitute/Proxies/DelegateProxy/DelegateProxyFactory.cs b/src/NSubstitute/Proxies/DelegateProxy/DelegateProxyFactory.cs index 9e38a1c1..66ee9e4e 100644 --- a/src/NSubstitute/Proxies/DelegateProxy/DelegateProxyFactory.cs +++ b/src/NSubstitute/Proxies/DelegateProxy/DelegateProxyFactory.cs @@ -8,9 +8,9 @@ public class DelegateProxyFactory(CastleDynamicProxyFactory objectProxyFactory) { private readonly CastleDynamicProxyFactory _castleObjectProxyFactory = objectProxyFactory ?? throw new ArgumentNullException(nameof(objectProxyFactory)); - public object GenerateProxy(ICallRouter callRouter, Type typeToProxy, Type[]? additionalInterfaces, object?[]? constructorArguments) + public object GenerateProxy(ICallRouter callRouter, Type typeToProxy, Type[]? additionalInterfaces, bool isPartial, object?[]? constructorArguments) { // Castle factory can now resolve delegate proxies as well. - return _castleObjectProxyFactory.GenerateProxy(callRouter, typeToProxy, additionalInterfaces, constructorArguments); + return _castleObjectProxyFactory.GenerateProxy(callRouter, typeToProxy, additionalInterfaces, isPartial, constructorArguments); } } \ No newline at end of file diff --git a/src/NSubstitute/Proxies/ProxyFactory.cs b/src/NSubstitute/Proxies/ProxyFactory.cs index f107ae85..c93ee284 100644 --- a/src/NSubstitute/Proxies/ProxyFactory.cs +++ b/src/NSubstitute/Proxies/ProxyFactory.cs @@ -5,11 +5,11 @@ namespace NSubstitute.Proxies; [Obsolete("This class is deprecated and will be removed in future versions of the product.")] public class ProxyFactory(IProxyFactory delegateFactory, IProxyFactory dynamicProxyFactory) : IProxyFactory { - public object GenerateProxy(ICallRouter callRouter, Type typeToProxy, Type[]? additionalInterfaces, object?[]? constructorArguments) + public object GenerateProxy(ICallRouter callRouter, Type typeToProxy, Type[]? additionalInterfaces, bool isPartial, object?[]? constructorArguments) { var isDelegate = typeToProxy.IsDelegate(); return isDelegate - ? delegateFactory.GenerateProxy(callRouter, typeToProxy, additionalInterfaces, constructorArguments) - : dynamicProxyFactory.GenerateProxy(callRouter, typeToProxy, additionalInterfaces, constructorArguments); + ? delegateFactory.GenerateProxy(callRouter, typeToProxy, additionalInterfaces, isPartial, constructorArguments) + : dynamicProxyFactory.GenerateProxy(callRouter, typeToProxy, additionalInterfaces, isPartial, constructorArguments); } } \ No newline at end of file diff --git a/src/NSubstitute/Substitute.cs b/src/NSubstitute/Substitute.cs index f1d471e6..c1710106 100644 --- a/src/NSubstitute/Substitute.cs +++ b/src/NSubstitute/Substitute.cs @@ -89,4 +89,24 @@ public static T ForPartsOf(params object[] constructorArguments) var substituteFactory = SubstitutionContext.Current.SubstituteFactory; return (T)substituteFactory.CreatePartial([typeof(T)], constructorArguments); } + + /// + /// Creates a proxy for a class that implements an interface, forwarding methods and properties to an instance of the class, effectively mimicking a real instance. + /// Both the interface and the class must be provided as parameters. + /// The proxy will log calls made to the interface members and delegate them to an instance of the class. Specific members can be substituted + /// by using When(() => call).DoNotCallBase() or by + /// setting a value to return value for that member. + /// This extension supports sealed classes and non-virtual members, with some limitations. Since the substituted method is non-virtual, internal calls within the object will invoke the original implementation and will not be logged. + /// + /// The interface the substitute will implement. + /// The class type implementing the interface. Must be a class; not a delegate or interface. + /// + /// An object implementing the selected interface. Calls will be forwarded to the actuall methods, but allows parts to be selectively + /// overridden via `Returns` and `When..DoNotCallBase`. + public static TInterface ForTypeForwardingTo(params object[] constructorArguments) + where TInterface : class + { + var substituteFactory = SubstitutionContext.Current.SubstituteFactory; + return (TInterface)substituteFactory.CreatePartial([typeof(TInterface), typeof(TClass)], constructorArguments); + } } \ No newline at end of file diff --git a/tests/NSubstitute.Acceptance.Specs/SubbingForConcreteTypesAndMultipleInterfaces.cs b/tests/NSubstitute.Acceptance.Specs/SubbingForConcreteTypesAndMultipleInterfaces.cs index 26d4d778..d9985987 100644 --- a/tests/NSubstitute.Acceptance.Specs/SubbingForConcreteTypesAndMultipleInterfaces.cs +++ b/tests/NSubstitute.Acceptance.Specs/SubbingForConcreteTypesAndMultipleInterfaces.cs @@ -1,4 +1,5 @@ using NUnit.Framework; +using NUnit.Framework.Legacy; namespace NSubstitute.Acceptance.Specs; @@ -31,6 +32,30 @@ public void Can_sub_for_concrete_type_and_implement_other_interfaces() subAsIFirst.Received().First(); } + [Test] + public void Can_sub_for_abstract_type_and_implement_other_two_interfaces() + { + // test from docs + var substitute = Substitute.For([typeof(IFirst), typeof(ISecond), typeof(ClassWithCtorArgs)], + ["hello world", 5]); + + ClassicAssert.IsInstanceOf(substitute); + ClassicAssert.IsInstanceOf(substitute); + ClassicAssert.IsInstanceOf(substitute); + } + + [Test] + public void Can_sub_for_concrete_type_and_implement_other_two_interfaces() + { + // test from docs + var substitute = Substitute.For([typeof(IFirst), typeof(ISecond), typeof(ConcreteClassWithCtorArgs)], + ["hello world", 5]); + + ClassicAssert.IsInstanceOf(substitute); + ClassicAssert.IsInstanceOf(substitute); + ClassicAssert.IsInstanceOf(substitute); + } + [Test] public void Partial_sub() { @@ -90,8 +115,13 @@ public class Partial public virtual int Number() { return -1; } public int GetNumberPlusOne() { return Number() + 1; } } + public abstract class ClassWithCtorArgs(string s, int a) { public string StringFromCtorArg { get; set; } = s; public int IntFromCtorArg { get; set; } = a; } + + public class ConcreteClassWithCtorArgs(string s, int a) : ClassWithCtorArgs(s, a) + { + } } \ No newline at end of file diff --git a/tests/NSubstitute.Acceptance.Specs/TypeForwarding.cs b/tests/NSubstitute.Acceptance.Specs/TypeForwarding.cs new file mode 100644 index 00000000..af4a8fa9 --- /dev/null +++ b/tests/NSubstitute.Acceptance.Specs/TypeForwarding.cs @@ -0,0 +1,104 @@ +using NSubstitute.Exceptions; +using NSubstitute.Extensions; +using NUnit.Framework; + +namespace NSubstitute.Acceptance.Specs; + +public class TypeForwarding +{ + [Test] + public void UseImplementedNonVirtualMethod() + { + var testAbstractClass = Substitute.ForTypeForwardingTo(); + Assert.That(testAbstractClass.MethodReturnsSameInt(1), Is.EqualTo(1)); + Assert.That(testAbstractClass.CalledTimes, Is.EqualTo(1)); + testAbstractClass.Received().MethodReturnsSameInt(1); + Assert.That(testAbstractClass.CalledTimes, Is.EqualTo(1)); + } + + [Test] + public void UseSubstitutedNonVirtualMethod() + { + var testInterface = Substitute.ForTypeForwardingTo(); + testInterface.Configure().MethodReturnsSameInt(1).Returns(2); + Assert.That(testInterface.MethodReturnsSameInt(1), Is.EqualTo(2)); + Assert.That(testInterface.MethodReturnsSameInt(3), Is.EqualTo(3)); + testInterface.ReceivedWithAnyArgs(2).MethodReturnsSameInt(default); + Assert.That(testInterface.CalledTimes, Is.EqualTo(1)); + } + + [Test] + public void UseSubstitutedNonVirtualMethodHonorsDoNotCallBase() + { + var testInterface = Substitute.ForTypeForwardingTo(); + testInterface.Configure().MethodReturnsSameInt(1).Returns(2); + testInterface.WhenForAnyArgs(x => x.MethodReturnsSameInt(default)).DoNotCallBase(); + Assert.That(testInterface.MethodReturnsSameInt(1), Is.EqualTo(2)); + Assert.That(testInterface.MethodReturnsSameInt(3), Is.EqualTo(0)); + testInterface.ReceivedWithAnyArgs(2).MethodReturnsSameInt(default); + Assert.That(testInterface.CalledTimes, Is.EqualTo(0)); + } + + [Test] + public void PartialSubstituteCallsConstructorWithParameters() + { + var testInterface = Substitute.ForTypeForwardingTo(50); + Assert.That(testInterface.MethodReturnsSameInt(1), Is.EqualTo(1)); + Assert.That(testInterface.CalledTimes, Is.EqualTo(51)); + } + + [Test] + public void PartialSubstituteFailsIfClassDoesntImplementInterface() + { + Assert.Throws( + () => Substitute.ForTypeForwardingTo()); + } + + [Test] + public void PartialSubstituteFailsIfClassIsAbstract() + { + Assert.Throws( + () => Substitute.ForTypeForwardingTo(), "The provided class is abstract."); + } + + public interface ITestInterface + { + public int CalledTimes { get; set; } + + void VoidTestMethod(); + int TestMethodReturnsInt(); + int MethodReturnsSameInt(int i); + } + + public sealed class TestSealedNonVirtualClass : ITestInterface + { + public TestSealedNonVirtualClass(int initialCounter) => CalledTimes = initialCounter; + public TestSealedNonVirtualClass() { } + + public int CalledTimes { get; set; } + + public int TestMethodReturnsInt() => throw new NotImplementedException(); + + public void VoidTestMethod() => throw new NotImplementedException(); + public int MethodReturnsSameInt(int i) + { + CalledTimes++; + return i; + } + } + + public abstract class TestAbstractClassWithInterface : ITestInterface + { + public int CalledTimes { get; set; } + + public abstract int MethodReturnsSameInt(int i); + + public abstract int TestMethodReturnsInt(); + + public abstract void VoidTestMethod(); + } + + public class TestRandomConcreteClass { } + + public abstract class TestAbstractClass { } +} \ No newline at end of file