Refine LimitedAccuracy's ⊑ semantics
As discussed in #48030, this is a different attempt to fix
the semantics of LimitedAccuracy. This fixes the same test
case as #48030, but keeps `LimitedAccuracy` ε smaller than
its wrapped lattice element. The primary change here is
that all lattice elements that are strictly `⊑ T` are
now also `⊑ LimitedAccuracy(T)`, whereas before that was only
true for other `LimitedAccuracy` elements.

Quoting the still relevant parts of #48030's commit message:

I was investigating some suboptimal inference in Diffractor
(which due to its recursive structure over the order of the
taken derivative likes to tickle recursion limiting) and
noticed that inference was performing some constant propagation,
but then discarding the result. Upon further investigation,
it turned out that inference had determined the function to be
`LimitedAccuracy(...)`, but constprop found out it actually
returned `Const`. Now, ordinarily, we don't constprop functions
that inference determined to be `LimitedAccuracy`, but this
function happened to have `@constprop :aggressive` annotated.

Of course, if constprop determines that the function actually
terminates, we do want to use that information. We could hardcode
this in `abstract_call_gf_by_type`, but it made me take a closer look
at the lattice operations for `LimitedAccuracy`, since in theory
`abstract_call_gf_by_type` should prefer a more precise result.
Keno committed Dec 30, 2022
1 parent 82fbf54 commit be9bd9a
Showing 4 changed files with 141 additions and 25 deletions.
12 changes: 11 additions & 1 deletion base/compiler/abstractinterpretation.jl
@@ -135,6 +135,8 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
if const_call_result.rt ⊑ₚ rt
rt = const_call_result.rt
(; effects, const_result, edge) = const_call_result
add_remark!(interp, sv, "[constprop] Discarded because the result was wider than inference")
all_effects = merge_effects(all_effects, effects)
@@ -169,6 +171,8 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
this_conditional = this_const_conditional
this_rt = this_const_rt
(; effects, const_result, edge) = const_call_result
add_remark!(interp, sv, "[constprop] Discarded because the result was wider than inference")
all_effects = merge_effects(all_effects, effects)
@@ -535,6 +539,7 @@ end

const RECURSION_UNUSED_MSG = "Bounded recursion detected with unused result. Annotated return type may be wider than true result."
const RECURSION_MSG = "Bounded recursion detected. Call was widened to force convergence."
const RECURSION_MSG_HARDLIMIT = "Bounded recursion detected under hardlimit. Call was widened to force convergence."

function abstract_call_method(interp::AbstractInterpreter, method::Method, @nospecialize(sig), sparams::SimpleVector, hardlimit::Bool, si::StmtInfo, sv::InferenceState)
if method.name === :depwarn && isdefined(Main, :Base) && method.module === Main.Base
@@ -573,6 +578,7 @@ function abstract_call_method(interp::AbstractInterpreter, method::Method, @nosp
washardlimit = hardlimit

if topmost !== nothing
sigtuple = unwrap_unionall(sig)::DataType
@@ -611,7 +617,11 @@ function abstract_call_method(interp::AbstractInterpreter, method::Method, @nosp
# (non-typically, this means that we lose the ability to detect a guaranteed StackOverflow in some cases)
return MethodCallResult(Any, true, true, nothing, Effects())
add_remark!(interp, sv, RECURSION_MSG)
if washardlimit
add_remark!(interp, sv, RECURSION_MSG_HARDLIMIT)
add_remark!(interp, sv, RECURSION_MSG)
topmost = topmost::InferenceState
parentframe = topmost.parent
poison_callstack(sv, parentframe === nothing ? topmost : parentframe)
55 changes: 48 additions & 7 deletions base/compiler/typelattice.jl
@@ -171,9 +171,44 @@ struct StateUpdate

# Represent that the type estimate has been approximated, due to "causes"
# (only used in abstract interpretation, doesn't appear in optimization)
# N.B. in the lattice, this is epsilon smaller than `typ` (except Union{})
"""
    struct LimitedAccuracy

A `LimitedAccuracy` lattice element is used to indicate that the true inference
result was approximate due to heuristic termination of a recursion. For example,
consider two call stacks starting from `A` and `B` that look like:

    A -> C -> A -> D
    B -> C -> A -> D

In the first case, inference may have decided that `A->C->A` constitutes a cycle,
widening the result it obtained for `C`, even if it might otherwise have been
able to obtain a result. In this case, the result inferred for `C` will be
annotated with this lattice type to indicate that the obtained result is an
upper bound for the non-limited inference. In particular, this means that the
call stack originating at `B` will re-perform inference without being poisoned
by the potentially inaccurate result obtained during the inference of `A`.

N.B.: We do *not* make any effort to ensure the reverse. For example, if `B`
is inferred first, then we may cache a precise result for `C` and re-use this
result while inferring `A`, even if inference of `A` would not have been able
to obtain this result due to limiting. This is undesirable, because it makes
some inference results order dependent, but it is unclear how this situation
could be avoided.

A `LimitedAccuracy` element wraps another lattice element (let's call it `T`)
and additionally tracks the `causes` due to which limitation occurred. As a
lattice element, `LimitedAccuracy(T)` is considered ε smaller than the
corresponding lattice element `T`, but in particular, all lattice elements that
are `⊑ T` (but not equal to `T`) are also `⊑ LimitedAccuracy(T)`.

The `causes` list is used to determine whether a particular cause of limitation is
inevitable and, if so, to widen `LimitedAccuracy(T)` back to `T`. For example,
in the call stacks above, if any call to `A` always leads back to `A`, then
it does not matter whether we start at `A` or reach it via `B`: Any inference
that reaches `A` will always hit the same limitation and the result may thus
be cached.
"""
struct LimitedAccuracy
@@ -182,6 +217,7 @@ struct LimitedAccuracy
return new(typ, causes)
LimitedAccuracy(@nospecialize(T), ::Nothing) = T

struct NotFound end
@@ -368,7 +404,8 @@ ignorelimited(typ::LimitedAccuracy) = typ.typ
function ⊑(lattice::InferenceLattice, @nospecialize(a), @nospecialize(b))
if isa(b, LimitedAccuracy)
if !isa(a, LimitedAccuracy)
return false
# epsilon smaller
return ⊑(widenlattice(lattice), a, b.typ) && !⊑(widenlattice(lattice), b.typ, a)
if b.causes ⊈ a.causes
return false
@@ -508,9 +545,13 @@ function ⊑(lattice::ConstsLattice, @nospecialize(a), @nospecialize(b))

function is_lattice_equal(lattice::InferenceLattice, @nospecialize(a), @nospecialize(b))
if isa(a, LimitedAccuracy) || isa(b, LimitedAccuracy)
# TODO: Unwrap these and recurse to is_lattice_equal
return ⊑(lattice, a, b) && ⊑(lattice, b, a)
if isa(a, LimitedAccuracy)
isa(b, LimitedAccuracy) || return false
a.causes == b.causes || return false
a = a.typ
b = b.typ
elseif isa(b, LimitedAccuracy)
return false
return is_lattice_equal(widenlattice(lattice), a, b)
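
The ordering and equality rules in this file can be illustrated with a small toy model. This is only a sketch under stated assumptions, not the compiler's implementation: plain Julia types stand in for lattice elements, `<:` stands in for the wider lattice's `⊑`, and `Limited`, `toy_leq`, and `toy_equal` are hypothetical names invented for the illustration.

```julia
# Hypothetical stand-in for `LimitedAccuracy`: a wrapped type plus a set of causes.
struct Limited
    typ::Type
    causes::Set{Symbol}
end

# Toy partial order. `<:` plays the role of ⊑ on the wider lattice (an assumption).
function toy_leq(a, b)
    if b isa Limited
        if !(a isa Limited)
            # `Limited(T)` is ε smaller than `T`: everything strictly below `T`
            # is also below `Limited(T)`, but `T` itself is not.
            return a <: b.typ && !(b.typ <: a)
        end
        # Both are limited: `a` must carry at least the causes of `b`.
        b.causes ⊆ a.causes || return false
        b = b.typ
    end
    a isa Limited && (a = a.typ)
    return a <: b
end

# Toy equality: limited elements are equal only to limited elements
# with the same causes and the same wrapped type.
function toy_equal(a, b)
    if a isa Limited
        b isa Limited || return false
        a.causes == b.causes || return false
        return a.typ == b.typ
    elseif b isa Limited
        return false
    end
    return a == b
end

lim = Limited(Integer, Set([:A]))
println(toy_leq(Int, lim))      # strictly below the wrapped type
println(toy_leq(Integer, lim))  # the wrapped type itself
println(toy_leq(lim, Integer))  # the limited element against its wrapped type
```

In this model, `Int` sits below `Limited(Integer)`, `Integer` does not, and `Limited(Integer)` sits below `Integer`, which is the "ε smaller" relationship the docstring describes.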
91 changes: 74 additions & 17 deletions base/compiler/typelimits.jl
@@ -385,29 +385,86 @@ function tmerge(lattice::OptimizerLattice, @nospecialize(typea), @nospecialize(t
return tmerge(widenlattice(lattice), typea, typeb)

function tmerge(lattice::InferenceLattice, @nospecialize(typea), @nospecialize(typeb))
r = tmerge_fast_path(lattice, typea, typeb)
r !== nothing && return r
function union_causes(causesa::IdSet{InferenceState}, causesb::IdSet{InferenceState})
if causesa ⊆ causesb
return causesb
elseif causesb ⊆ causesa
return causesa
return union!(copy(causesa), causesb)

function merge_causes(causesa::IdSet{InferenceState}, causesb::IdSet{InferenceState})
# TODO: When lattice elements are equal, we're allowed to discard one or the
# other set, but we'll need to come up with a consistent rule. For now, we
# just check the length, but other heuristics may be applicable.
if length(causesa) < length(causesb)
return causesa
elseif length(causesb) < length(causesa)
return causesb
return union!(copy(causesa), causesb)

@noinline function tmerge_limited(lattice::InferenceLattice, @nospecialize(typea), @nospecialize(typeb))
typea === Union{} && return typeb
typeb === Union{} && return typea

# type-lattice for LimitedAccuracy wrapper
# the merge creates a slightly narrower type than needed, but we can't
# represent the precise intersection of causes and don't attempt to
# enumerate some of these cases where we could
# Like tmerge_fast_path, but tracking which causes need to be preserved at
# the same time.
if isa(typea, LimitedAccuracy) && isa(typeb, LimitedAccuracy)
if typea.causes ⊆ typeb.causes
causes = typeb.causes
elseif typeb.causes ⊆ typea.causes
causes = typea.causes
causesa = typea.causes
causesb = typeb.causes
typea = typea.typ
typeb = typeb.typ
suba = ⊑(lattice, typea, typeb)
subb = ⊑(lattice, typeb, typea)

# Approximated types are lattice equal. Merge causes.
if suba && subb
causes = merge_causes(causesa, causesb)
issimplertype(lattice, typeb, typea) && return LimitedAccuracy(typeb, causesb)
elseif suba
issimplertype(lattice, typeb, typea) && return LimitedAccuracy(typeb, causesb)
causes = causesb
# `a`'s causes may be discarded
elseif subb
causes = causesa
causes = union!(copy(typea.causes), typeb.causes)
causes = union_causes(causesa, causesb)
if isa(typeb, LimitedAccuracy)
(typea, typeb) = (typeb, typea)

causes = typea.causes
typea = typea.typ

suba = ⊑(lattice, typea, typeb)
if suba
issimplertype(lattice, typeb, typea) && return typeb
# `typea` was a subtype of `b`. Whatever tmerge produces,
# we know it must be a supertype of `typeb`, so we may drop the
# causes.
causes = nothing
return LimitedAccuracy(tmerge(widenlattice(lattice), typea.typ, typeb.typ), causes)
elseif isa(typea, LimitedAccuracy)
return LimitedAccuracy(tmerge(widenlattice(lattice), typea.typ, typeb), typea.causes)
elseif isa(typeb, LimitedAccuracy)
return LimitedAccuracy(tmerge(widenlattice(lattice), typea, typeb.typ), typeb.causes)
subb = ⊑(lattice, typeb, typea)

subb && issimplertype(lattice, typea, typeb) && return LimitedAccuracy(typea, causes)
return LimitedAccuracy(tmerge(widenlattice(lattice), typea, typeb), causes)

function tmerge(lattice::InferenceLattice, @nospecialize(typea), @nospecialize(typeb))
if isa(typea, LimitedAccuracy) || isa(typeb, LimitedAccuracy)
return tmerge_limited(lattice, typea, typeb)

r = tmerge_fast_path(widenlattice(lattice), typea, typeb)
r !== nothing && return r
return tmerge(widenlattice(lattice), typea, typeb)
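
The two cause-combining helpers in this file can be restated as a self-contained sketch. The assumptions here are deliberate simplifications: `Set{Symbol}` replaces `IdSet{InferenceState}`, and the `toy_` prefixed names are hypothetical so they are not mistaken for the compiler's functions.

```julia
# Keep the superset when one set of causes contains the other;
# otherwise build a fresh union without mutating either input.
function toy_union_causes(causesa::Set{Symbol}, causesb::Set{Symbol})
    if causesa ⊆ causesb
        return causesb
    elseif causesb ⊆ causesa
        return causesa
    else
        return union!(copy(causesa), causesb)
    end
end

# For lattice-equal types either set may be kept; this heuristic prefers
# the smaller one and falls back to the union when the sizes match.
function toy_merge_causes(causesa::Set{Symbol}, causesb::Set{Symbol})
    if length(causesa) < length(causesb)
        return causesa
    elseif length(causesb) < length(causesa)
        return causesb
    else
        return union!(copy(causesa), causesb)
    end
end

a = Set([:A])
ab = Set([:A, :B])
println(toy_union_causes(a, ab))
println(toy_merge_causes(a, ab))
```

The `copy` before `union!` matters: the cause sets are shared between lattice elements, so the merge must not modify either argument in place.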

8 changes: 8 additions & 0 deletions test/compiler/inference.jl
@@ -4701,3 +4701,11 @@ let # jl_widen_core_extended_info

# This is somewhat sensitive to the exact recursion level that inference is willing to do, but the intention
# is to test the case where inference limited a recursion, but then a forced constprop nevertheless managed
# to terminate the call.
@Base.constprop :aggressive type_level_recurse1(x...) = x[1] == 2 ? 1 : (length(x) > 100 ? x : type_level_recurse2(x[1] + 1, x..., x...))
@Base.constprop :aggressive type_level_recurse2(x...) = type_level_recurse1(x...)
type_level_recurse_entry() = Val{type_level_recurse1(1)}()
@test Base.return_types(type_level_recurse_entry, ()) |> only == Val{1}
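
The test above relies on a constant reaching a type parameter during inference. A much simpler, hypothetical function (not part of the commit, and involving no recursion limiting) shows the same inspection pattern with `Base.return_types` and `only`:

```julia
# Hypothetical example function: the constant `1 + 1` ends up in the type parameter.
val_entry() = Val{1 + 1}()

# `Base.return_types` returns one inferred return type per matching method;
# `only` checks that there is exactly one and unwraps it.
rt = only(Base.return_types(val_entry, ()))
println(rt)
```

If inference could not resolve the constant, the inferred type would be the wider `Val` rather than a concrete `Val{N}`, which is the kind of precision loss the test in this commit guards against for the recursive case.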
