Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ability to mock protected methods with and without return value #845

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/NSubstitute/Core/IThreadLocalContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,12 @@ public interface IThreadLocalContext
void EnqueueArgumentSpecification(IArgumentSpecification spec);
IList<IArgumentSpecification> DequeueAllArgumentSpecifications();

/// <summary>
/// Peeks into the argument specifications
/// </summary>
/// <returns>Enqueued argument specifications</returns>
IList<IArgumentSpecification> PeekAllArgumentSpecifications();

void SetPendingRaisingEventArgumentsFactory(Func<ICall, object?[]> getArguments);
/// <summary>
/// Returns the previously set arguments factory and resets the stored value.
Expand Down
17 changes: 17 additions & 0 deletions src/NSubstitute/Core/ThreadLocalContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,23 @@ public IList<IArgumentSpecification> DequeueAllArgumentSpecifications()
return queue;
}

/// <inheritdoc/>
public IList<IArgumentSpecification> PeekAllArgumentSpecifications()
{
var queue = _argumentSpecifications.Value;

if (queue?.Count > 0)
{
var items = new IArgumentSpecification[queue.Count];

queue.CopyTo(items, 0);

return items;
}

return EmptySpecifications;
}

public void SetPendingRaisingEventArgumentsFactory(Func<ICall, object?[]> getArguments)
{
_getArgumentsForRaisingEvent.Value = getArguments;
Expand Down
10 changes: 10 additions & 0 deletions src/NSubstitute/Exceptions/ProtectedMethodNotFoundException.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
namespace NSubstitute.Exceptions;

public class ProtectedMethodNotFoundException(string message, Exception? innerException) : SubstituteException(message, innerException)
{
public ProtectedMethodNotFoundException() : this("", null)
{ }

public ProtectedMethodNotFoundException(string message) : this(message, null)
{ }
}
10 changes: 10 additions & 0 deletions src/NSubstitute/Exceptions/ProtectedMethodNotVirtualException.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
namespace NSubstitute.Exceptions;

public class ProtectedMethodNotVirtualException(string message, Exception? innerException) : SubstituteException(message, innerException)
{
public ProtectedMethodNotVirtualException() : this("", null)
{ }

public ProtectedMethodNotVirtualException(string message) : this(message, null)
{ }
}
80 changes: 80 additions & 0 deletions src/NSubstitute/Extensions/ProtectedExtensions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
using System.Reflection;
using NSubstitute.Core;
using NSubstitute.Core.Arguments;
using NSubstitute.Exceptions;

// Disable nullability for client API, so it does not affect clients.
#nullable disable annotations

namespace NSubstitute.Extensions;

public static class ProtectedExtensions
{
/// <summary>
/// Configure behavior for a protected method with return value
/// </summary>
/// <typeparam name="T"></typeparam>
/// <param name="obj">The object.</param>
/// <param name="methodName">Name of the method.</param>
/// <param name="args">The method arguments.</param>
/// <returns>Result object from the method invocation.</returns>
/// <exception cref="NSubstitute.Exceptions.NullSubstituteReferenceException">Substitute - Cannot mock null object</exception>
/// <exception cref="NSubstitute.Exceptions.ProtectedMethodNotFoundException">Error mocking method. Method must be protected virtual and with correct matching arguments and type</exception>
/// <exception cref="System.ArgumentException">Must provide valid protected method name to mock - methodName</exception>
public static object Protected<T>(this T obj, string methodName, params object[] args) where T : class
{
if (obj == null) { throw new NullSubstituteReferenceException(); }
if (string.IsNullOrWhiteSpace(methodName)) { throw new ArgumentException("Must provide valid protected method name to mock", nameof(methodName)); }

IList<IArgumentSpecification> argTypes = SubstitutionContext.Current.ThreadContext.PeekAllArgumentSpecifications();
MethodInfo mthdInfo = obj.GetType().GetMethod(methodName, BindingFlags.NonPublic | BindingFlags.Instance, Type.DefaultBinder, argTypes.Select(x => x.ForType).ToArray(), null);

if (mthdInfo == null)
{
_ = SubstitutionContext.Current.ThreadContext.DequeueAllArgumentSpecifications();
throw new ProtectedMethodNotFoundException($"No protected virtual method found with signature {methodName}({string.Join(", ", argTypes.Select(x => x.ForType))}) in {obj.GetType().BaseType!.Name}. " +
"Check that the method name and arguments are correct. Public virtual methods must use standard NSubstitute mocking. See the documentation for additional info.");
}
if (!mthdInfo.IsVirtual)
{
_ = SubstitutionContext.Current.ThreadContext.DequeueAllArgumentSpecifications();
throw new ProtectedMethodNotVirtualException($"{mthdInfo} is not virtual. NSubstitute can only work with virtual members of the class that are overridable in the test assembly");
}

return mthdInfo.Invoke(obj, args);
}

/// <summary>
/// Configure behavior for a protected method with no return vlaue
/// </summary>
/// <typeparam name="T"></typeparam>
/// <param name="obj">The object.</param>
/// <param name="methodName">Name of the method.</param>
/// <param name="args">The method arguments.</param>
/// <returns>WhenCalled&lt;T&gt;.</returns>
/// <exception cref="NSubstitute.Exceptions.NullSubstituteReferenceException">Substitute - Cannot mock null object</exception>
/// <exception cref="NSubstitute.Exceptions.ProtectedMethodNotFoundException">Error mocking method. Method must be protected virtual and with correct matching arguments and type</exception>
/// <exception cref="System.ArgumentException">Must provide valid protected method name to mock - methodName</exception>
public static WhenCalled<T> When<T>(this T obj, string methodName, params object[] args) where T : class
Jason31569 marked this conversation as resolved.
Show resolved Hide resolved
{
if (obj == null) { throw new NullSubstituteReferenceException(); }
if (string.IsNullOrWhiteSpace(methodName)) { throw new ArgumentException("Must provide valid protected method name to mock", nameof(methodName)); }

IList<IArgumentSpecification> argTypes = SubstitutionContext.Current.ThreadContext.PeekAllArgumentSpecifications();
MethodInfo mthdInfo = obj.GetType().GetMethod(methodName, BindingFlags.NonPublic | BindingFlags.Instance, Type.DefaultBinder, argTypes.Select(y => y.ForType).ToArray(), null);

if (mthdInfo == null)
{
_ = SubstitutionContext.Current.ThreadContext.DequeueAllArgumentSpecifications();
throw new ProtectedMethodNotFoundException($"No protected virtual method found with signature {methodName}({string.Join(", ", argTypes.Select(x => x.ForType))}) in {obj.GetType().BaseType!.Name}. " +
"Check that the method name and arguments are correct. Public virtual methods must use standard NSubstitute mocking. See the documentation for additional info.");
}
if (!mthdInfo.IsVirtual)
{
_ = SubstitutionContext.Current.ThreadContext.DequeueAllArgumentSpecifications();
throw new ProtectedMethodNotVirtualException($"{mthdInfo} is not virtual. NSubstitute can only work with virtual members of the class that are overridable in the test assembly");
}

return new WhenCalled<T>(SubstitutionContext.Current, obj, x => mthdInfo.Invoke(x, args), MatchArgs.AsSpecifiedInCall);
}
}
51 changes: 51 additions & 0 deletions tests/NSubstitute.Acceptance.Specs/Infrastructure/AnotherClass.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
namespace NSubstitute.Acceptance.Specs.Infrastructure;

public abstract class AnotherClass
{
protected abstract string ProtectedMethod();

protected abstract string ProtectedMethod(int i);

protected abstract string ProtectedMethod(string msg, int i, char j);

protected abstract void ProtectedMethodWithNoReturn();

protected abstract void ProtectedMethodWithNoReturn(int i);

protected abstract void ProtectedMethodWithNoReturn(string msg, int i, char j);

public abstract void PublicVirtualMethod();

protected void ProtectedNonVirtualMethod()
{ }

public string DoWork()
{
return ProtectedMethod();
}

public string DoWork(int i)
{
return ProtectedMethod(i);
}

public string DoWork(string msg, int i, char j)
{
return ProtectedMethod(msg, i, j);
}

public void DoVoidWork()
{
ProtectedMethodWithNoReturn();
}

public void DoVoidWork(int i)
{
ProtectedMethodWithNoReturn(i);
}

public void DoVoidWork(string msg, int i, char j)
{
ProtectedMethodWithNoReturn(msg, i, j);
}
}
Loading