Refine LimitedAccuracy's ⊑ semantics #48045

Merged
merged 2 commits into from
Jan 2, 2023
12 changes: 11 additions & 1 deletion base/compiler/abstractinterpretation.jl
@@ -135,6 +135,8 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
if const_call_result.rt ⊑ₚ rt
rt = const_call_result.rt
(; effects, const_result, edge) = const_call_result
else
add_remark!(interp, sv, "[constprop] Discarded because the result was wider than inference")
end
end
all_effects = merge_effects(all_effects, effects)
@@ -169,6 +171,8 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
this_conditional = this_const_conditional
this_rt = this_const_rt
(; effects, const_result, edge) = const_call_result
else
add_remark!(interp, sv, "[constprop] Discarded because the result was wider than inference")
end
end
all_effects = merge_effects(all_effects, effects)
@@ -535,6 +539,7 @@ end

const RECURSION_UNUSED_MSG = "Bounded recursion detected with unused result. Annotated return type may be wider than true result."
const RECURSION_MSG = "Bounded recursion detected. Call was widened to force convergence."
const RECURSION_MSG_HARDLIMIT = "Bounded recursion detected under hardlimit. Call was widened to force convergence."

function abstract_call_method(interp::AbstractInterpreter, method::Method, @nospecialize(sig), sparams::SimpleVector, hardlimit::Bool, si::StmtInfo, sv::InferenceState)
if method.name === :depwarn && isdefined(Main, :Base) && method.module === Main.Base
@@ -573,6 +578,7 @@ function abstract_call_method(interp::AbstractInterpreter, method::Method, @nosp
end
end
end
washardlimit = hardlimit

if topmost !== nothing
sigtuple = unwrap_unionall(sig)::DataType
@@ -611,7 +617,11 @@ function abstract_call_method(interp::AbstractInterpreter, method::Method, @nosp
# (non-typically, this means that we lose the ability to detect a guaranteed StackOverflow in some cases)
return MethodCallResult(Any, true, true, nothing, Effects())
end
add_remark!(interp, sv, RECURSION_MSG)
if washardlimit
add_remark!(interp, sv, RECURSION_MSG_HARDLIMIT)
else
add_remark!(interp, sv, RECURSION_MSG)
end
topmost = topmost::InferenceState
parentframe = topmost.parent
poison_callstack(sv, parentframe === nothing ? topmost : parentframe)
77 changes: 61 additions & 16 deletions base/compiler/typelattice.jl
@@ -171,9 +171,44 @@ struct StateUpdate
conditional::Bool
end

# Represent that the type estimate has been approximated, due to "causes"
# (only used in abstract interpretation, doesn't appear in optimization)
# N.B. in the lattice, this is epsilon smaller than `typ` (except Union{})
"""
struct LimitedAccuracy

A `LimitedAccuracy` lattice element is used to indicate that the true inference
result was approximate due to heuristic termination of a recursion. For example,
consider two call stacks starting from `A` and `B` that look like:

A -> C -> A -> D
B -> C -> A -> D

In the first case, inference may have decided that `A->C->A` constitutes a cycle,
widening the result it obtained for `C`, even if it might otherwise have been
able to obtain a result. In this case, the result inferred for `C` will be
annotated with this lattice type to indicate that the obtained result is an
upper bound for the non-limited inference. In particular, this means that the
call stack originating at `B` will re-perform inference without being poisoned
by the potentially inaccurate result obtained during the inference of `A`.

N.B.: We do *not* take any efforts to ensure the reverse. For example, if `B`
is inferred first, then we may cache a precise result for `C` and re-use this
result while inferring `A`, even if inference of `A` would not have been able
to obtain this result due to limiting. This is undesirable, because it makes
some inference results order dependent, but it is unclear how this situation
could be avoided.

A `LimitedAccuracy` element wraps another lattice element (let's call it `T`)
and additionally tracks the `causes` due to which limitation occurred. As a
lattice element, `LimitedAccuracy(T)` is considered ε smaller than the
corresponding lattice element `T`, but in particular, all lattice elements that
are `⊑ T` (but not equal to `T`) are also `⊑ LimitedAccuracy(T)`.

The `causes` list is used to determine whether a particular cause of limitation is
inevitable and, if so, to widen `LimitedAccuracy(T)` back to `T`. For example,
in the call stacks above, if any call to `A` always leads back to `A`, then
it does not matter whether we start at `A` or reach it via `B`: Any inference
that reaches `A` will always hit the same limitation and the result may thus
be cached.
"""
struct LimitedAccuracy
typ
causes::IdSet{InferenceState}
@@ -182,6 +217,7 @@ struct LimitedAccuracy
return new(typ, causes)
end
end
LimitedAccuracy(@nospecialize(T), ::Nothing) = T

"""
struct NotFound end
@@ -366,17 +402,22 @@ ignorelimited(typ::LimitedAccuracy) = typ.typ
# =============

function ⊑(lattice::InferenceLattice, @nospecialize(a), @nospecialize(b))
if isa(b, LimitedAccuracy)
if !isa(a, LimitedAccuracy)
return false
end
if b.causes ⊈ a.causes
return false
end
b = b.typ
r = ⊑(widenlattice(lattice), ignorelimited(a), ignorelimited(b))
r || return false
isa(b, LimitedAccuracy) || return true

# We've found that ignorelimited(a) ⊑ ignorelimited(b).
# Now perform the reverse query to check for equality.
ab_eq = ⊑(widenlattice(lattice), b.typ, ignorelimited(a))

if !ab_eq
# a's unlimited type is strictly smaller than b's
return true
end
isa(a, LimitedAccuracy) && (a = a.typ)
return ⊑(widenlattice(lattice), a, b)

# a and b's unlimited types are equal.
isa(a, LimitedAccuracy) || return false # b is limited, so ε smaller
return a.causes ⊆ b.causes
Member

The old code used to have the inverse check here

Suggested change
- return a.causes ⊆ b.causes
+ return b.causes ⊆ a.causes

Was this supposed to get switched here? I think that is incorrect now, as the wider result should have fewer causes, but this does the opposite.

Member Author

Let me double check. I don't think I intended to switch it - it's possible I got my cases mixed up.
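
To make the question about the direction of the cause check concrete, here is a minimal toy sketch of the ordering, not the compiler's implementation. It assumes plain Julia types ordered by `<:` stand in for the underlying lattice and `Set{Symbol}` stands in for `IdSet{InferenceState}`; `Limited`, `unwrap`, `toy_leq` and the `reversed` keyword are hypothetical names introduced only for illustration.

```julia
# Toy stand-in for the compiler's LimitedAccuracy; all names here are hypothetical.
struct Limited
    typ::Type
    causes::Set{Symbol}
end

unwrap(x::Limited) = x.typ
unwrap(x::Type) = x

# Follows the structure of the new `⊑` shown in this commit, with `<:` as the
# underlying order. `reversed=true` applies the direction suggested in review.
function toy_leq(a, b; reversed::Bool=false)
    unwrap(a) <: unwrap(b) || return false
    b isa Limited || return true      # unlimited `b`: nothing more to check
    if !(b.typ <: unwrap(a))
        return true                   # a's unwrapped type is strictly smaller
    end
    # The unwrapped types are equal from here on.
    a isa Limited || return false     # `b` is limited, so it is ε smaller than `a`
    return reversed ? b.causes ⊆ a.causes : a.causes ⊆ b.causes
end

a = Limited(Int, Set([:A]))
b = Limited(Int, Set([:A, :B]))
println(toy_leq(a, b))                 # direction as written in this commit
println(toy_leq(a, b; reversed=true))  # direction suggested in review
```

With equal unwrapped types, the two directions disagree exactly when one cause set is a strict subset of the other, which is the case the review comment is about.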

end

function ⊑(lattice::OptimizerLattice, @nospecialize(a), @nospecialize(b))
@@ -508,9 +549,13 @@ function ⊑(lattice::ConstsLattice, @nospecialize(a), @nospecialize(b))
end

function is_lattice_equal(lattice::InferenceLattice, @nospecialize(a), @nospecialize(b))
if isa(a, LimitedAccuracy) || isa(b, LimitedAccuracy)
# TODO: Unwrap these and recurse to is_lattice_equal
return ⊑(lattice, a, b) && ⊑(lattice, b, a)
if isa(a, LimitedAccuracy)
isa(b, LimitedAccuracy) || return false
a.causes == b.causes || return false
a = a.typ
b = b.typ
elseif isa(b, LimitedAccuracy)
return false
end
return is_lattice_equal(widenlattice(lattice), a, b)
end
94 changes: 75 additions & 19 deletions base/compiler/typelimits.jl
@@ -304,8 +304,6 @@ end
# A simplified type_more_complex query over the extended lattice
# (assumes typeb ⊑ typea)
function issimplertype(𝕃::AbstractLattice, @nospecialize(typea), @nospecialize(typeb))
typea = ignorelimited(typea)
typeb = ignorelimited(typeb)
typea isa MaybeUndef && (typea = typea.typ) # n.b. does not appear in inference
typeb isa MaybeUndef && (typeb = typeb.typ) # n.b. does not appear in inference
typea === typeb && return true
@@ -385,29 +383,87 @@ function tmerge(lattice::OptimizerLattice, @nospecialize(typea), @nospecialize(t
return tmerge(widenlattice(lattice), typea, typeb)
end

function tmerge(lattice::InferenceLattice, @nospecialize(typea), @nospecialize(typeb))
r = tmerge_fast_path(lattice, typea, typeb)
r !== nothing && return r
function union_causes(causesa::IdSet{InferenceState}, causesb::IdSet{InferenceState})
if causesa ⊆ causesb
return causesb
elseif causesb ⊆ causesa
return causesa
else
return union!(copy(causesa), causesb)
end
end

function merge_causes(causesa::IdSet{InferenceState}, causesb::IdSet{InferenceState})
# TODO: When lattice elements are equal, we're allowed to discard one or the
# other set, but we'll need to come up with a consistent rule. For now, we
# just check the length, but other heuristics may be applicable.
if length(causesa) < length(causesb)
return causesa
elseif length(causesb) < length(causesa)
return causesb
else
return union!(copy(causesa), causesb)
end
end

@noinline function tmerge_limited(lattice::InferenceLattice, @nospecialize(typea), @nospecialize(typeb))
typea === Union{} && return typeb
typeb === Union{} && return typea

# type-lattice for LimitedAccuracy wrapper
# the merge create a slightly narrower type than needed, but we can't
# represent the precise intersection of causes and don't attempt to
# enumerate some of these cases where we could
# Like tmerge_fast_path, but tracking which causes need to be preserved at
# the same time.
if isa(typea, LimitedAccuracy) && isa(typeb, LimitedAccuracy)
if typea.causes ⊆ typeb.causes
causes = typeb.causes
elseif typeb.causes ⊆ typea.causes
causes = typea.causes
causesa = typea.causes
causesb = typeb.causes
typea = typea.typ
typeb = typeb.typ
suba = ⊑(lattice, typea, typeb)
subb = ⊑(lattice, typeb, typea)

# Approximated types are lattice equal. Merge causes.
if suba && subb
causes = merge_causes(causesa, causesb)
issimplertype(lattice, typeb, typea) && return LimitedAccuracy(typeb, causesb)
Member

Suggested change
- issimplertype(lattice, typeb, typea) && return LimitedAccuracy(typeb, causesb)
+ issimplertype(lattice, typeb, typea) && return LimitedAccuracy(typeb, causes)

Member

I think symmetry and equality demands that you are not supposed to check issimplertype in this case (see tmerge_fast_path)

Member Author

Well, we are still changing the lattice element here potentially if the causes are different, but yes fair enough, I'll drop the issimplertype check.

Member

Good point though. Is this new tmerge algorithm monotonic and convergent? In particular, this merge call here (with the heuristic towards shorter lengths but calling union if undecidable) seems like it might make inference non-terminating, as it could repeatedly merge and drop causes lists here in a loop.
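
The concern can be seen on a small toy example. The sketch below re-creates the two cause-combining helpers from this diff over `Set{Symbol}` instead of `IdSet{InferenceState}` (an assumption made so the code runs outside the compiler); `toy_union_causes` and `toy_merge_causes` are hypothetical names. It only shows that the length heuristic can return a set that is not a superset of both inputs; it does not by itself demonstrate non-termination.

```julia
# Toy version of `union_causes`: the result always contains both inputs.
function toy_union_causes(a::Set{Symbol}, b::Set{Symbol})
    a ⊆ b && return b
    b ⊆ a && return a
    return union(a, b)
end

# Toy version of `merge_causes`: prefers the shorter set, unions on a tie.
function toy_merge_causes(a::Set{Symbol}, b::Set{Symbol})
    length(a) < length(b) && return a
    length(b) < length(a) && return b
    return union(a, b)
end

step1 = toy_merge_causes(Set([:A]), Set([:B]))   # tie: both causes are kept
step2 = toy_merge_causes(step1, Set([:C]))       # shorter set wins: earlier causes are dropped
println(step1)
println(step2)
println(toy_union_causes(step1, Set([:C])))      # the union-based helper keeps everything
```

Because `toy_merge_causes` can shrink the set between successive merges, the sequence of cause sets is not monotonically growing, which is the property the question above asks about.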

elseif suba
issimplertype(lattice, typeb, typea) && return LimitedAccuracy(typeb, causesb)
causes = causesb
# `a`'s causes may be discarded
elseif subb
causes = causesa
else
causes = union!(copy(typea.causes), typeb.causes)
causes = union_causes(causesa, causesb)
end
else
if isa(typeb, LimitedAccuracy)
(typea, typeb) = (typeb, typea)
end
typea = typea::LimitedAccuracy

causes = typea.causes
typea = typea.typ

suba = ⊑(lattice, typea, typeb)
if suba
issimplertype(lattice, typeb, typea) && return typeb
# `typea` was a subtype of `b`. Whatever tmerge produces,
# we know it must be a supertype of `typeb`, so we may drop the
# causes.
causes = nothing
end
return LimitedAccuracy(tmerge(widenlattice(lattice), typea.typ, typeb.typ), causes)
elseif isa(typea, LimitedAccuracy)
return LimitedAccuracy(tmerge(widenlattice(lattice), typea.typ, typeb), typea.causes)
elseif isa(typeb, LimitedAccuracy)
return LimitedAccuracy(tmerge(widenlattice(lattice), typea, typeb.typ), typeb.causes)
subb = ⊑(lattice, typeb, typea)
end

subb && issimplertype(lattice, typea, typeb) && return LimitedAccuracy(typea, causes)
return LimitedAccuracy(tmerge(widenlattice(lattice), typea, typeb), causes)
end

function tmerge(lattice::InferenceLattice, @nospecialize(typea), @nospecialize(typeb))
if isa(typea, LimitedAccuracy) || isa(typeb, LimitedAccuracy)
return tmerge_limited(lattice, typea, typeb)
end

r = tmerge_fast_path(widenlattice(lattice), typea, typeb)
r !== nothing && return r
return tmerge(widenlattice(lattice), typea, typeb)
end

8 changes: 8 additions & 0 deletions test/compiler/inference.jl
@@ -4701,3 +4701,11 @@ let # jl_widen_core_extended_info
widened
end
end

# This is somewhat sensitive to the exact recursion level that inference is willing to do, but the intention
# is to test the case where inference limited a recursion, but then a forced constprop nevertheless managed
# to terminate the call.
@Base.constprop :aggressive type_level_recurse1(x...) = x[1] == 2 ? 1 : (length(x) > 100 ? x : type_level_recurse2(x[1] + 1, x..., x...))
@Base.constprop :aggressive type_level_recurse2(x...) = type_level_recurse1(x...)
type_level_recurse_entry() = Val{type_level_recurse1(1)}()
@test Base.return_types(type_level_recurse_entry, ()) |> only == Val{1}
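
For readers unfamiliar with the assertion pattern used in the test above, here is a minimal sketch of the same query on a trivial function. `toy_entry` is a hypothetical helper and is not part of the PR's test file; it involves no recursion or constant propagation, so it only illustrates how the result of `Base.return_types` is read.

```julia
# Hypothetical helper, not part of the PR's test file.
toy_entry() = Val{1}()

# `return_types` returns one inferred return type per matching method;
# `only` unwraps the single element and throws if there is not exactly one.
# `|>` binds tighter than `==`, so the test line above parses as
# `(Base.return_types(f, ()) |> only) == Val{1}`.
inferred = Base.return_types(toy_entry, ()) |> only
println(inferred)
```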