Skip to content

Commit

Permalink
* Use custom exception on protected method mock errors and additional…
Browse files Browse the repository at this point in the history
… unit tests to verify error conditions

* Rename unit tests with better description
* Fix unit test failure due to orphaned argument spec when previous setup throws
  • Loading branch information
Jason31569 committed Jan 2, 2025
1 parent 88998b2 commit 18191e2
Show file tree
Hide file tree
Showing 6 changed files with 151 additions and 11 deletions.
3 changes: 1 addition & 2 deletions src/NSubstitute/Core/ThreadLocalContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,8 @@ public IList<IArgumentSpecification> DequeueAllArgumentSpecifications()
public IList<IArgumentSpecification> PeekAllArgumentSpecifications()
{
var queue = _argumentSpecifications.Value;
if (queue == null) { throw new SubstituteInternalException("Argument specification queue is null."); }

if (queue.Count > 0)
if (queue?.Count > 0)
{
var items = new IArgumentSpecification[queue.Count];

Expand Down
10 changes: 10 additions & 0 deletions src/NSubstitute/Exceptions/ProtectedMethodNotFoundException.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
namespace NSubstitute.Exceptions;

public class ProtectedMethodNotFoundException(string message, Exception? innerException) : SubstituteException(message, innerException)
{
public ProtectedMethodNotFoundException() : this("", null)
{ }

public ProtectedMethodNotFoundException(string message) : this(message, null)
{ }
}
10 changes: 10 additions & 0 deletions src/NSubstitute/Exceptions/ProtectedMethodNotVirtualException.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
namespace NSubstitute.Exceptions;

public class ProtectedMethodNotVirtualException(string message, Exception? innerException) : SubstituteException(message, innerException)
{
public ProtectedMethodNotVirtualException() : this("", null)
{ }

public ProtectedMethodNotVirtualException(string message) : this(message, null)
{ }
}
28 changes: 24 additions & 4 deletions src/NSubstitute/Extensions/ProtectedExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ public static class ProtectedExtensions
/// <param name="args">The method arguments.</param>
/// <returns>Result object from the method invocation.</returns>
/// <exception cref="NSubstitute.Exceptions.NullSubstituteReferenceException">Substitute - Cannot mock null object</exception>
/// <exception cref="NSubstitute.Exceptions.ProtectedMethodNotFoundException">Error mocking method. Method must be protected virtual and with correct matching arguments and type</exception>
/// <exception cref="System.ArgumentException">Must provide valid protected method name to mock - methodName</exception>
public static object Protected<T>(this T obj, string methodName, params object[] args) where T : class
{
Expand All @@ -28,8 +29,17 @@ public static object Protected<T>(this T obj, string methodName, params object[]
IList<IArgumentSpecification> argTypes = SubstitutionContext.Current.ThreadContext.PeekAllArgumentSpecifications();
MethodInfo mthdInfo = obj.GetType().GetMethod(methodName, BindingFlags.NonPublic | BindingFlags.Instance, Type.DefaultBinder, argTypes.Select(x => x.ForType).ToArray(), null);

if (mthdInfo == null) { throw new Exception($"Method {methodName} not found"); }
if (!mthdInfo.IsVirtual) { throw new Exception($"Method {methodName} is not virtual"); }
if (mthdInfo == null)
{
_ = SubstitutionContext.Current.ThreadContext.DequeueAllArgumentSpecifications();
throw new ProtectedMethodNotFoundException($"No protected virtual method found with signature {methodName}({string.Join(", ", argTypes.Select(x => x.ForType))}) in {obj.GetType().BaseType!.Name}. " +
"Check that the method name and arguments are correct. Public virtual methods must use standard NSubstitute mocking. See the documentation for additional info.");
}
if (!mthdInfo.IsVirtual)
{
_ = SubstitutionContext.Current.ThreadContext.DequeueAllArgumentSpecifications();
throw new ProtectedMethodNotVirtualException($"{mthdInfo} is not virtual. NSubstitute can only work with virtual members of the class that are overridable in the test assembly");
}

return mthdInfo.Invoke(obj, args);
}
Expand All @@ -43,6 +53,7 @@ public static object Protected<T>(this T obj, string methodName, params object[]
/// <param name="args">The method arguments.</param>
/// <returns>WhenCalled&lt;T&gt;.</returns>
/// <exception cref="NSubstitute.Exceptions.NullSubstituteReferenceException">Substitute - Cannot mock null object</exception>
/// <exception cref="NSubstitute.Exceptions.ProtectedMethodNotFoundException">Error mocking method. Method must be protected virtual and with correct matching arguments and type</exception>
/// <exception cref="System.ArgumentException">Must provide valid protected method name to mock - methodName</exception>
public static WhenCalled<T> When<T>(this T obj, string methodName, params object[] args) where T : class
{
Expand All @@ -52,8 +63,17 @@ public static WhenCalled<T> When<T>(this T obj, string methodName, params object
IList<IArgumentSpecification> argTypes = SubstitutionContext.Current.ThreadContext.PeekAllArgumentSpecifications();
MethodInfo mthdInfo = obj.GetType().GetMethod(methodName, BindingFlags.NonPublic | BindingFlags.Instance, Type.DefaultBinder, argTypes.Select(y => y.ForType).ToArray(), null);

if (mthdInfo == null) { throw new Exception($"Method {methodName} not found"); }
if (!mthdInfo.IsVirtual) { throw new Exception($"Method {methodName} is not virtual"); }
if (mthdInfo == null)
{
_ = SubstitutionContext.Current.ThreadContext.DequeueAllArgumentSpecifications();
throw new ProtectedMethodNotFoundException($"No protected virtual method found with signature {methodName}({string.Join(", ", argTypes.Select(x => x.ForType))}) in {obj.GetType().BaseType!.Name}. " +
"Check that the method name and arguments are correct. Public virtual methods must use standard NSubstitute mocking. See the documentation for additional info.");
}
if (!mthdInfo.IsVirtual)
{
_ = SubstitutionContext.Current.ThreadContext.DequeueAllArgumentSpecifications();
throw new ProtectedMethodNotVirtualException($"{mthdInfo} is not virtual. NSubstitute can only work with virtual members of the class that are overridable in the test assembly");
}

return new WhenCalled<T>(SubstitutionContext.Current, obj, x => mthdInfo.Invoke(x, args), MatchArgs.AsSpecifiedInCall);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@ public abstract class AnotherClass

protected abstract void ProtectedMethodWithNoReturn(string msg, int i, char j);

public abstract void PublicVirtualMethod();

protected void ProtectedNonVirtualMethod()
{ }

public string DoWork()
{
return ProtectedMethod();
Expand Down
106 changes: 101 additions & 5 deletions tests/NSubstitute.Acceptance.Specs/ProtectedExtensionsTests.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using NSubstitute.Acceptance.Specs.Infrastructure;
using NSubstitute.Exceptions;
using NSubstitute.Extensions;
using NUnit.Framework;

Expand Down Expand Up @@ -29,8 +30,7 @@ public void Should_mock_and_verify_protected_method_with_arg()
sub.Protected("ProtectedMethod", Arg.Any<int>()).Returns(expectedMsg);

Assert.That(worker.DoMoreWork(sub, 5), Is.EqualTo(expectedMsg));
var a = sub.Received(1);
a.Protected("ProtectedMethod", Arg.Any<int>());
sub.Received(1).Protected("ProtectedMethod", Arg.Any<int>());
}

[Test]
Expand All @@ -47,7 +47,55 @@ public void Should_mock_and_verify_protected_method_with_multiple_args()
}

[Test]
public void Should_mock_and_verify_method_with_no_return_and_no_args()
public void Should_throw_on_mock_null_substitute()
{
Assert.Throws<NullSubstituteReferenceException>(() => (null as AnotherClass).Protected("ProtectedMethod"));
}

[TestCase("")]
[TestCase(" ")]
[TestCase(null)]
public void Should_throw_on_mock_invalid_method_name(string methodName)
{
var sub = Substitute.For<AnotherClass>();

Assert.Throws<ArgumentException>(() => sub.Protected(methodName));
}

[Test]
public void Should_throw_on_mock_method_not_found()
{
var sub = Substitute.For<AnotherClass>();

Assert.Throws<ProtectedMethodNotFoundException>(() => sub.Protected("MethodDoesNotExist"));
}

[Test]
public void Should_throw_on_mock_method_arg_mismatch()
{
var sub = Substitute.For<AnotherClass>();

Assert.Throws<ProtectedMethodNotFoundException>(() => sub.Protected("ProtectedMethod", Arg.Any<IEnumerable<char>>()));
}

[Test]
public void Should_throw_on_mock_public_virtual_method()
{
var sub = Substitute.For<AnotherClass>();

Assert.Throws<ProtectedMethodNotFoundException>(() => sub.Protected("PublicVirtualMethod"));
}

[Test]
public void Should_throw_on_mock_non_virtual()
{
var sub = Substitute.For<AnotherClass>();

Assert.Throws<ProtectedMethodNotVirtualException>(() => sub.Protected("ProtectedNonVirtualMethod"));
}

[Test]
public void Should_mock_and_verify_void_method_and_no_args()
{
var count = 0;
var sub = Substitute.For<AnotherClass>();
Expand All @@ -61,7 +109,7 @@ public void Should_mock_and_verify_method_with_no_return_and_no_args()
}

[Test]
public void Should_mock_and_verify_method_with_no_return_with_arg()
public void Should_mock_and_verify_void_method_with_arg()
{
var count = 0;
var sub = Substitute.For<AnotherClass>();
Expand All @@ -75,7 +123,7 @@ public void Should_mock_and_verify_method_with_no_return_with_arg()
}

[Test]
public void Should_mock_and_verify_method_with_no_return_with_multiple_args()
public void Should_mock_and_verify_void_method_with_multiple_args()
{
var count = 0;
var sub = Substitute.For<AnotherClass>();
Expand All @@ -88,6 +136,54 @@ public void Should_mock_and_verify_method_with_no_return_with_multiple_args()
sub.Received(1).Protected("ProtectedMethodWithNoReturn", Arg.Any<string>(), Arg.Any<int>(), Arg.Any<char>());
}

[Test]
public void Should_throw_on_void_method_null_substitute()
{
Assert.Throws<NullSubstituteReferenceException>(() => (null as AnotherClass).When("ProtectedMethod"));
}

[TestCase("")]
[TestCase(" ")]
[TestCase(null)]
public void Should_throw_on_mock_void_method_invalid_method_name(string methodName)
{
var sub = Substitute.For<AnotherClass>();

Assert.Throws<ArgumentException>(() => sub.When(methodName));
}

[Test]
public void Should_throw_on_mock_void_method_not_found()
{
var sub = Substitute.For<AnotherClass>();

Assert.Throws<ProtectedMethodNotFoundException>(() => sub.When("MethodDoesNotExist"));
}

[Test]
public void Should_throw_on_mock_void_method_arg_mismatch()
{
var sub = Substitute.For<AnotherClass>();

Assert.Throws<ProtectedMethodNotFoundException>(() => sub.When("ProtectedMethod", Arg.Any<IEnumerable<char>>()));
}

[Test]
public void Should_throw_on_mock_public_virtual_void_method()
{
var sub = Substitute.For<AnotherClass>();

Assert.Throws<ProtectedMethodNotFoundException>(() => sub.When("PublicVirtualMethod"));
}

[Test]
public void Should_throw_on_mock_non_virtual_void_method()
{
var sub = Substitute.For<AnotherClass>();

Assert.Throws<ProtectedMethodNotVirtualException>(() => sub.When("ProtectedNonVirtualMethod"));
}

private class Worker
{
internal string DoWork(AnotherClass worker)
Expand Down

0 comments on commit 18191e2

Please sign in to comment.