Skip to content

Commit

Permalink
Refine LimitedAccuracy's ⊑ semantics
Browse files Browse the repository at this point in the history
As discussed in #48030, this is a different attempt to fix
the semantics of LimitedAccuracy. This fixes the same test
case as #48030, but keeps `LimitedAccuracy` ε smaller than
its wrapped lattice element. The primary change here is
that now all lattice elements that are strictly `⊑ T` are
now also `⊑ LimitedAccuracy(T)`, whereas before that was only
true for other `LimitedAccuracy` elements.

Quoting the still relevant parts of #48030's commit message:

```
I was investigating some suboptimal inference in Diffractor
(which due to its recursive structure over the order of the
taken derivative likes to tickle recursion limiting) and
noticed that inference was performing some constant propagation,
but then discarding the result. Upon further investigation,
it turned out that inference had determined the function to be
`LimitedAccuracy(...)`, but constprop found out it actually
returned `Const`. Now, ordinarily, we don't constprop functions
that inference determined to be `LimitedAccuracy`, but this
function happened to have `@constprop :aggressive` annotated.

Of course, if constprop determines that the function actually
terminates, we do want to use that information. We could hardcode
this in abstract_call_gf_by_type, but it made me take a closer look
at the lattice operations for `LimitedAccuracy`, since in theory
`abstract_call_gf_by_type` should prefer a more precise result.
```
  • Loading branch information
Keno committed Dec 31, 2022
1 parent 82fbf54 commit 0c131e9
Show file tree
Hide file tree
Showing 4 changed files with 155 additions and 36 deletions.
12 changes: 11 additions & 1 deletion base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,8 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
if const_call_result.rt ₚ rt
rt = const_call_result.rt
(; effects, const_result, edge) = const_call_result
else
add_remark!(interp, sv, "[constprop] Discarded because the result was wider than inference")
end
end
all_effects = merge_effects(all_effects, effects)
Expand Down Expand Up @@ -169,6 +171,8 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
this_conditional = this_const_conditional
this_rt = this_const_rt
(; effects, const_result, edge) = const_call_result
else
add_remark!(interp, sv, "[constprop] Discarded because the result was wider than inference")
end
end
all_effects = merge_effects(all_effects, effects)
Expand Down Expand Up @@ -535,6 +539,7 @@ end

const RECURSION_UNUSED_MSG = "Bounded recursion detected with unused result. Annotated return type may be wider than true result."
const RECURSION_MSG = "Bounded recursion detected. Call was widened to force convergence."
const RECURSION_MSG_HARDLIMIT = "Bounded recursion detected under hardlimit. Call was widened to force convergence."

function abstract_call_method(interp::AbstractInterpreter, method::Method, @nospecialize(sig), sparams::SimpleVector, hardlimit::Bool, si::StmtInfo, sv::InferenceState)
if method.name === :depwarn && isdefined(Main, :Base) && method.module === Main.Base
Expand Down Expand Up @@ -573,6 +578,7 @@ function abstract_call_method(interp::AbstractInterpreter, method::Method, @nosp
end
end
end
washardlimit = hardlimit

if topmost !== nothing
sigtuple = unwrap_unionall(sig)::DataType
Expand Down Expand Up @@ -611,7 +617,11 @@ function abstract_call_method(interp::AbstractInterpreter, method::Method, @nosp
# (non-typically, this means that we lose the ability to detect a guaranteed StackOverflow in some cases)
return MethodCallResult(Any, true, true, nothing, Effects())
end
add_remark!(interp, sv, RECURSION_MSG)
if washardlimit
add_remark!(interp, sv, RECURSION_MSG_HARDLIMIT)
else
add_remark!(interp, sv, RECURSION_MSG)
end
topmost = topmost::InferenceState
parentframe = topmost.parent
poison_callstack(sv, parentframe === nothing ? topmost : parentframe)
Expand Down
77 changes: 61 additions & 16 deletions base/compiler/typelattice.jl
Original file line number Diff line number Diff line change
Expand Up @@ -171,9 +171,44 @@ struct StateUpdate
conditional::Bool
end

# Represent that the type estimate has been approximated, due to "causes"
# (only used in abstract interpretation, doesn't appear in optimization)
# N.B. in the lattice, this is epsilon smaller than `typ` (except Union{})
"""
struct LimitedAccuracy
A `LimitedAccuracy` lattice element is used to indicate that the true inference
result was approximate due to heuristic termination of a recursion. For example,
consider two call stacks starting from `A` and `B` that look like:
A -> C -> A -> D
B -> C -> A -> D
In the first case, inference may have decided that `A->C->A` constitutes a cycle,
widening the result it obtained for `C`, even if it might otherwise have been
able to obtain a result. In this case, the result inferred for `C` will be
annotated with this lattice type to indicate that the obtained result is an
upper bound for the non-limited inference. In particular, this means that the
call stack originating at `B` will re-perform inference without being poisoned
by the potentially inaccurate result obtained during the inference of `A`.
N.B.: We do *not* take any efforts to ensure the reverse. For example, if `B`
is inferred first, then we may cache a precise result for `C` and re-use this
result while inferring `A`, even if inference of `A` would have not been able
to obtain this result due to limiting. This is undesirable, because it makes
some inference results order dependent, but there it is unclear how this situation
could be avoided.
A `LimitedAccuracy` element wraps another lattice element (let's call it `T`)
and additionally tracks the `causes` due to which limitation occurred. As a
lattice element, `LimitedAccuracy(T)` is considered ε smaller than the
corresponding lattice element `T`, but in particular, all lattice elements that
are `⊑ T` (but not equal `T`) are also `⊑ LimitedAccuracy(T)`.
The `causes` list is used to determine whether a particular cause of limitation is
enevitable and if so, widening `LimitedAccuracy(T)` back to `T`. For example,
in the call stacks above, if any call to `A` always leads back to `A`, then
it does not matter whether we start at `A` or reach it via `B`: Any inference
that reaches `A` will always hit the same limitation and the result may thus
be cached.
"""
struct LimitedAccuracy
typ
causes::IdSet{InferenceState}
Expand All @@ -182,6 +217,7 @@ struct LimitedAccuracy
return new(typ, causes)
end
end
LimitedAccuracy(@nospecialize(T), ::Nothing) = T

"""
struct NotFound end
Expand Down Expand Up @@ -366,17 +402,22 @@ ignorelimited(typ::LimitedAccuracy) = typ.typ
# =============

function (lattice::InferenceLattice, @nospecialize(a), @nospecialize(b))
if isa(b, LimitedAccuracy)
if !isa(a, LimitedAccuracy)
return false
end
if b.causes a.causes
return false
end
b = b.typ
r = (widenlattice(lattice), ignorelimited(a), ignorelimited(b))
r || return false
isa(b, LimitedAccuracy) || return true

# We've found that ignorelimited(a) ⊑ ignorelimited(b).
# Now perform the reverse query to check for equality.
ab_eq = (widenlattice(lattice), b.typ, ignorelimited(a))

if !ab_eq
# a's unlimited type is strictly smaller than b's
return true
end
isa(a, LimitedAccuracy) && (a = a.typ)
return (widenlattice(lattice), a, b)

# a and b's unlimited types are equal.
isa(a, LimitedAccuracy) || return false # b is limited, so ε smaller
return a.causes b.causes
end

function (lattice::OptimizerLattice, @nospecialize(a), @nospecialize(b))
Expand Down Expand Up @@ -508,9 +549,13 @@ function ⊑(lattice::ConstsLattice, @nospecialize(a), @nospecialize(b))
end

function is_lattice_equal(lattice::InferenceLattice, @nospecialize(a), @nospecialize(b))
if isa(a, LimitedAccuracy) || isa(b, LimitedAccuracy)
# TODO: Unwrap these and recurse to is_lattice_equal
return (lattice, a, b) && (lattice, b, a)
if isa(a, LimitedAccuracy)
isa(b, LimitedAccuracy) || return false
a.causes == b.causes || return false
a = a.typ
b = b.typ
elseif isa(b, LimitedAccuracy)
return false
end
return is_lattice_equal(widenlattice(lattice), a, b)
end
Expand Down
94 changes: 75 additions & 19 deletions base/compiler/typelimits.jl
Original file line number Diff line number Diff line change
Expand Up @@ -304,8 +304,6 @@ end
# A simplified type_more_complex query over the extended lattice
# (assumes typeb ⊑ typea)
function issimplertype(𝕃::AbstractLattice, @nospecialize(typea), @nospecialize(typeb))
typea = ignorelimited(typea)
typeb = ignorelimited(typeb)
typea isa MaybeUndef && (typea = typea.typ) # n.b. does not appear in inference
typeb isa MaybeUndef && (typeb = typeb.typ) # n.b. does not appear in inference
typea === typeb && return true
Expand Down Expand Up @@ -385,29 +383,87 @@ function tmerge(lattice::OptimizerLattice, @nospecialize(typea), @nospecialize(t
return tmerge(widenlattice(lattice), typea, typeb)
end

function tmerge(lattice::InferenceLattice, @nospecialize(typea), @nospecialize(typeb))
r = tmerge_fast_path(lattice, typea, typeb)
r !== nothing && return r
function union_causes(causesa::IdSet{InferenceState}, causesb::IdSet{InferenceState})
if causesa causesb
return causesb
elseif causesb causesa
return causesa
else
return union!(copy(causesa), causesb)
end
end

function merge_causes(causesa::IdSet{InferenceState}, causesb::IdSet{InferenceState})
# TODO: When lattice elements are equal, we're allowed to discard one or the
# other set, but we'll need to come up with a consistent rule. For now, we
# just check the length, but other heuristics may be applicable.
if length(causesa) < length(causesb)
return causesa
elseif length(causesb) < length(causesa)
return causesb
else
return union!(copy(causesa), causesb)
end
end

@noinline function tmerge_limited(lattice::InferenceLattice, @nospecialize(typea), @nospecialize(typeb))
typea === Union{} && return typeb
typeb === Union{} && return typea

# type-lattice for LimitedAccuracy wrapper
# the merge create a slightly narrower type than needed, but we can't
# represent the precise intersection of causes and don't attempt to
# enumerate some of these cases where we could
# Like tmerge_fast_path, but tracking which causes need to be preserved at
# the same time.
if isa(typea, LimitedAccuracy) && isa(typeb, LimitedAccuracy)
if typea.causes typeb.causes
causes = typeb.causes
elseif typeb.causes typea.causes
causes = typea.causes
causesa = typea.causes
causesb = typeb.causes
typea = typea.typ
typeb = typeb.typ
suba = (lattice, typea, typeb)
subb = (lattice, typeb, typea)

# Approximated types are lattice equal. Merge causes.
if suba && subb
causes = merge_causes(causesa, causesb)
issimplertype(lattice, typeb, typea) && return LimitedAccuracy(typeb, causesb)
elseif suba
issimplertype(lattice, typeb, typea) && return LimitedAccuracy(typeb, causesb)
causes = causesb
# `a`'s causes may be discarded
elseif subb
causes = causesa
else
causes = union!(copy(typea.causes), typeb.causes)
causes = union_causes(causesa, causesb)
end
else
if isa(typeb, LimitedAccuracy)
(typea, typeb) = (typeb, typea)
end
typea = typea::LimitedAccuracy

causes = typea.causes
typea = typea.typ

suba = (lattice, typea, typeb)
if suba
issimplertype(lattice, typeb, typea) && return typeb
# `typea` was a subtype of `b`. Whatever tmerge produces,
# we know it must be a supertype of `typeb`, so we may drop the
# causes.
causes = nothing
end
return LimitedAccuracy(tmerge(widenlattice(lattice), typea.typ, typeb.typ), causes)
elseif isa(typea, LimitedAccuracy)
return LimitedAccuracy(tmerge(widenlattice(lattice), typea.typ, typeb), typea.causes)
elseif isa(typeb, LimitedAccuracy)
return LimitedAccuracy(tmerge(widenlattice(lattice), typea, typeb.typ), typeb.causes)
subb = (lattice, typeb, typea)
end

subb && issimplertype(lattice, typea, typeb) && return LimitedAccuracy(typea, causes)
return LimitedAccuracy(tmerge(widenlattice(lattice), typea, typeb), causes)
end

function tmerge(lattice::InferenceLattice, @nospecialize(typea), @nospecialize(typeb))
if isa(typea, LimitedAccuracy) || isa(typeb, LimitedAccuracy)
return tmerge_limited(lattice, typea, typeb)
end

r = tmerge_fast_path(widenlattice(lattice), typea, typeb)
r !== nothing && return r
return tmerge(widenlattice(lattice), typea, typeb)
end

Expand Down
8 changes: 8 additions & 0 deletions test/compiler/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4701,3 +4701,11 @@ let # jl_widen_core_extended_info
widened
end
end

# This is somewhat sensitive to the exact recursion level that inference is willing to do, but the intention
# is to test the case where inference limited a recursion, but then a forced constprop nevertheless managed
# to terminate the call.
@Base.constprop :aggressive type_level_recurse1(x...) = x[1] == 2 ? 1 : (length(x) > 100 ? x : type_level_recurse2(x[1] + 1, x..., x...))
@Base.constprop :aggressive type_level_recurse2(x...) = type_level_recurse1(x...)
type_level_recurse_entry() = Val{type_level_recurse1(1)}()
@test Base.return_types(type_level_recurse_entry, ()) |> only == Val{1}

0 comments on commit 0c131e9

Please sign in to comment.