Skip to content

Commit

Permalink
Avoid stack overflow due to deep recursion on long chain of calls. (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
AlekseyTs authored Apr 29, 2023
1 parent e3f89a4 commit 53ac77e
Show file tree
Hide file tree
Showing 21 changed files with 1,106 additions and 292 deletions.
56 changes: 53 additions & 3 deletions src/Compilers/CSharp/Portable/Binder/Binder_Invocation.cs
Original file line number Diff line number Diff line change
Expand Up @@ -175,17 +175,67 @@ private BoundExpression BindInvocationExpression(
BindArgumentsAndNames(node.ArgumentList, diagnostics, analyzedArguments, allowArglist: false);
result = BindArgListOperator(node, diagnostics, analyzedArguments);
}
else if (receiverIsInvocation(node, out InvocationExpressionSyntax nested))
{
var invocations = ArrayBuilder<InvocationExpressionSyntax>.GetInstance();

invocations.Push(node);
node = nested;
while (receiverIsInvocation(node, out nested))
{
invocations.Push(node);
node = nested;
}

BoundExpression boundExpression = BindMethodGroup(node.Expression, invoked: true, indexed: false, diagnostics: diagnostics);

while (true)
{
result = bindArgumentsAndInvocation(node, boundExpression, analyzedArguments, diagnostics);
nested = node;

if (!invocations.TryPop(out node))
{
break;
}

Debug.Assert(node.Expression.Kind() is SyntaxKind.SimpleMemberAccessExpression);
var memberAccess = (MemberAccessExpressionSyntax)node.Expression;
analyzedArguments.Clear();
VerifyUnchecked(nested, diagnostics, result); // BindExpression does this after calling BindExpressionInternal
boundExpression = BindMemberAccessWithBoundLeft(memberAccess, result, memberAccess.Name, memberAccess.OperatorToken, invoked: true, indexed: false, diagnostics);
}

invocations.Free();
}
else
{
BoundExpression boundExpression = BindMethodGroup(node.Expression, invoked: true, indexed: false, diagnostics: diagnostics);
result = bindArgumentsAndInvocation(node, boundExpression, analyzedArguments, diagnostics);
}

analyzedArguments.Free();
return result;

BoundExpression bindArgumentsAndInvocation(InvocationExpressionSyntax node, BoundExpression boundExpression, AnalyzedArguments analyzedArguments, BindingDiagnosticBag diagnostics)
{
boundExpression = CheckValue(boundExpression, BindValueKind.RValueOrMethodGroup, diagnostics);
string name = boundExpression.Kind == BoundKind.MethodGroup ? GetName(node.Expression) : null;
BindArgumentsAndNames(node.ArgumentList, diagnostics, analyzedArguments, allowArglist: true);
result = BindInvocationExpression(node, node.Expression, name, boundExpression, analyzedArguments, diagnostics);
return BindInvocationExpression(node, node.Expression, name, boundExpression, analyzedArguments, diagnostics);
}

analyzedArguments.Free();
return result;
static bool receiverIsInvocation(InvocationExpressionSyntax node, out InvocationExpressionSyntax nested)
{
if (node.Expression is MemberAccessExpressionSyntax { Expression: InvocationExpressionSyntax receiver, RawKind: (int)SyntaxKind.SimpleMemberAccessExpression } && !receiver.MayBeNameofOperator())
{
nested = receiver;
return true;
}

nested = null;
return false;
}
}

private BoundExpression BindArgListOperator(InvocationExpressionSyntax node, BindingDiagnosticBag diagnostics, AnalyzedArguments analyzedArguments)
Expand Down
53 changes: 44 additions & 9 deletions src/Compilers/CSharp/Portable/Binder/ExpressionVariableFinder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -65,15 +65,6 @@ protected void FindExpressionVariables(
_variablesBuilder = save;
}

public override void Visit(SyntaxNode node)
{
if (node != null)
{
// no stackguard
((CSharpSyntaxNode)node).Accept(this);
}
}

public override void VisitSwitchExpression(SwitchExpressionSyntax node)
{
Visit(node.GoverningExpression);
Expand Down Expand Up @@ -348,6 +339,50 @@ public override void VisitBinaryExpression(BinaryExpressionSyntax node)
operands.Free();
}

public override void VisitInvocationExpression(InvocationExpressionSyntax node)
{
if (receiverIsInvocation(node, out InvocationExpressionSyntax nested))
{
var invocations = ArrayBuilder<InvocationExpressionSyntax>.GetInstance();

invocations.Push(node);

node = nested;
while (receiverIsInvocation(node, out nested))
{
invocations.Push(node);
node = nested;
}

Visit(node.Expression);

do
{
Visit(node.ArgumentList);
}
while (invocations.TryPop(out node));

invocations.Free();
}
else
{
Visit(node.Expression);
Visit(node.ArgumentList);
}

static bool receiverIsInvocation(InvocationExpressionSyntax node, out InvocationExpressionSyntax nested)
{
if (node.Expression is MemberAccessExpressionSyntax { Expression: InvocationExpressionSyntax receiver })
{
nested = receiver;
return true;
}

nested = null;
return false;
}
}

public override void VisitDeclarationExpression(DeclarationExpressionSyntax node)
{
var argumentSyntax = node.Parent as ArgumentSyntax;
Expand Down
43 changes: 42 additions & 1 deletion src/Compilers/CSharp/Portable/Binder/LocalBinderFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using ReferenceEqualityComparer = Roslyn.Utilities.ReferenceEqualityComparer;

namespace Microsoft.CodeAnalysis.CSharp
Expand Down Expand Up @@ -240,9 +241,49 @@ public override void VisitInvocationExpression(InvocationExpressionSyntax node)
return;
}

base.VisitInvocationExpression(node);
if (receiverIsInvocation(node, out InvocationExpressionSyntax? nested))
{
var invocations = ArrayBuilder<InvocationExpressionSyntax>.GetInstance();

invocations.Push(node);

node = nested;
while (receiverIsInvocation(node, out nested))
{
invocations.Push(node);
node = nested;
}

Visit(node.Expression);

do
{
Visit(node.ArgumentList);
}
while (invocations.TryPop(out node!));

invocations.Free();
}
else
{
Visit(node.Expression);
Visit(node.ArgumentList);
}

return;

static bool receiverIsInvocation(InvocationExpressionSyntax node, [NotNullWhen(true)] out InvocationExpressionSyntax? nested)
{
if (node.Expression is MemberAccessExpressionSyntax { Expression: InvocationExpressionSyntax receiver } && !receiver.MayBeNameofOperator())
{
nested = receiver;
return true;
}

nested = null;
return false;
}

static Symbol getAttributeTarget(Binder current)
{
Debug.Assert((current.Flags & BinderFlags.InContextualAttributeBinder) != 0);
Expand Down
12 changes: 8 additions & 4 deletions src/Compilers/CSharp/Portable/Binder/RefSafetyAnalysis.cs
Original file line number Diff line number Diff line change
Expand Up @@ -647,8 +647,6 @@ private void SetPatternLocalScopes(BoundObjectPattern pattern)

private void VisitArgumentsAndGetArgumentPlaceholders(BoundExpression? receiverOpt, ImmutableArray<BoundExpression> arguments)
{
Visit(receiverOpt);

for (int i = 0; i < arguments.Length; i++)
{
var arg = arguments[i];
Expand All @@ -663,7 +661,7 @@ private void VisitArgumentsAndGetArgumentPlaceholders(BoundExpression? receiverO
}
}

public override BoundNode? VisitCall(BoundCall node)
protected override void VisitArguments(BoundCall node)
{
VisitArgumentsAndGetArgumentPlaceholders(node.ReceiverOpt, node.Arguments);

Expand All @@ -682,7 +680,12 @@ private void VisitArgumentsAndGetArgumentPlaceholders(BoundExpression? receiverO
_diagnostics);
}

return null;
#if DEBUG
if (_visited is { } && _visited.Count <= MaxTrackVisited)
{
_visited.Add(node);
}
#endif
}

private void GetInterpolatedStringPlaceholders(
Expand Down Expand Up @@ -786,6 +789,7 @@ private void VisitObjectCreationExpressionBase(BoundObjectCreationExpressionBase

public override BoundNode? VisitIndexerAccess(BoundIndexerAccess node)
{
Visit(node.ReceiverOpt);
VisitArgumentsAndGetArgumentPlaceholders(node.ReceiverOpt, node.Arguments);

if (!node.HasErrors)
Expand Down
47 changes: 47 additions & 0 deletions src/Compilers/CSharp/Portable/BoundTree/BoundTreeWalker.cs
Original file line number Diff line number Diff line change
Expand Up @@ -135,5 +135,52 @@ protected BoundTreeWalkerWithStackGuardWithoutRecursionOnTheLeftOfBinaryOperator
rightOperands.Free();
return null;
}

public sealed override BoundNode? VisitCall(BoundCall node)
{
if (node.ReceiverOpt is BoundCall receiver1)
{
var calls = ArrayBuilder<BoundCall>.GetInstance();

calls.Push(node);

node = receiver1;
while (node.ReceiverOpt is BoundCall receiver2)
{
calls.Push(node);
node = receiver2;
}

VisitReceiver(node);

do
{
VisitArguments(node);
}
while (calls.TryPop(out node!));

calls.Free();
}
else
{
VisitReceiver(node);
VisitArguments(node);
}

return null;
}

/// <summary>
/// Called only for the first (in evaluation order) <see cref="BoundCall"/> in the chain.
/// </summary>
protected virtual void VisitReceiver(BoundCall node)
{
this.Visit(node.ReceiverOpt);
}

protected virtual void VisitArguments(BoundCall node)
{
this.VisitList(node.Arguments);
}
}
}
39 changes: 22 additions & 17 deletions src/Compilers/CSharp/Portable/CodeGen/EmitAddress.cs
Original file line number Diff line number Diff line change
Expand Up @@ -111,10 +111,8 @@ private LocalDefinition EmitAddress(BoundExpression expression, AddressKind addr

case BoundKind.Call:
var call = (BoundCall)expression;
var methodRefKind = call.Method.RefKind;

if (methodRefKind == RefKind.Ref ||
(IsAnyReadOnly(addressKind) && methodRefKind == RefKind.RefReadOnly))
if (UseCallResultAsAddress(call, addressKind))
{
EmitCallExpression(call, UseKind.UsedAsAddress);
break;
Expand Down Expand Up @@ -178,6 +176,13 @@ private LocalDefinition EmitAddress(BoundExpression expression, AddressKind addr
return null;
}

private static bool UseCallResultAsAddress(BoundCall call, AddressKind addressKind)
{
var methodRefKind = call.Method.RefKind;
return methodRefKind == RefKind.Ref ||
(IsAnyReadOnly(addressKind) && methodRefKind == RefKind.RefReadOnly);
}

private LocalDefinition EmitPassByCopyAddress(BoundPassByCopy passByCopyExpr, AddressKind addressKind)
{
// Normally we can just defer PassByCopy to the `default`,
Expand Down Expand Up @@ -505,7 +510,7 @@ private LocalDefinition EmitReceiverRef(BoundExpression receiver, AddressKind ad
return null;
}

if (receiverType.TypeKind == TypeKind.TypeParameter)
if (BoxNonVerifierReferenceReceiver(receiverType, addressKind))
{
//[Note: Constraints on a generic parameter only restrict the types that
//the generic parameter may be instantiated with. Verification (see Partition III)
Expand All @@ -514,26 +519,26 @@ private LocalDefinition EmitReceiverRef(BoundExpression receiver, AddressKind ad
//via the generic parameter unless it is first boxed (see Partition III) or
//the callvirt instruction is prefixed with the constrained. prefix instruction
//(see Partition III). end note]
if (addressKind == AddressKind.Constrained)
EmitExpression(receiver, used: true);
// conditional receivers are already boxed if needed when pushed
if (receiver.Kind != BoundKind.ConditionalReceiver)
{
return EmitAddress(receiver, addressKind);
}
else
{
EmitExpression(receiver, used: true);
// conditional receivers are already boxed if needed when pushed
if (receiver.Kind != BoundKind.ConditionalReceiver)
{
EmitBox(receiver.Type, receiver.Syntax);
}
return null;
EmitBox(receiver.Type, receiver.Syntax);
}

return null;
}

Debug.Assert(receiverType.IsVerifierValue());
Debug.Assert(receiverType.TypeKind == TypeKind.TypeParameter || receiverType.IsValueType);
return EmitAddress(receiver, addressKind);
}

private static bool BoxNonVerifierReferenceReceiver(TypeSymbol receiverType, AddressKind addressKind)
{
Debug.Assert(!receiverType.IsVerifierReference());
return receiverType.TypeKind == TypeKind.TypeParameter && addressKind != AddressKind.Constrained;
}

/// <summary>
/// May introduce a temp which it will return. (otherwise returns null)
/// </summary>
Expand Down
Loading

0 comments on commit 53ac77e

Please sign in to comment.