Skip to content

Commit

Permalink
optimize abstract_call_gf_by_type (#56572)
Browse files Browse the repository at this point in the history
Combines many allocations into one and types them for better inference
  • Loading branch information
aviatesk authored Nov 18, 2024
1 parent c5899c2 commit af9e6e3
Showing 1 changed file with 120 additions and 93 deletions.
213 changes: 120 additions & 93 deletions Compiler/src/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,76 @@ function propagate_conditional(rt::InterConditional, cond::Conditional)
return Conditional(cond.slot, new_thentype, new_elsetype)
end

function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
mutable struct SafeBox{T}
x::T
SafeBox{T}(x::T) where T = new{T}(x)
SafeBox(@nospecialize x) = new{Any}(x)
end
getindex(box::SafeBox) = box.x
setindex!(box::SafeBox{T}, x::T) where T = setfield!(box, :x, x)

struct FailedMethodMatch
reason::String
end

struct MethodMatchTarget
match::MethodMatch
edges::Vector{Union{Nothing,CodeInstance}}
edge_idx::Int
end

struct MethodMatches
applicable::Vector{MethodMatchTarget}
info::MethodMatchInfo
valid_worlds::WorldRange
end
any_ambig(result::MethodLookupResult) = result.ambig
any_ambig(info::MethodMatchInfo) = any_ambig(info.results)
any_ambig(m::MethodMatches) = any_ambig(m.info)
fully_covering(info::MethodMatchInfo) = info.fullmatch
fully_covering(m::MethodMatches) = fully_covering(m.info)

struct UnionSplitMethodMatches
applicable::Vector{MethodMatchTarget}
applicable_argtypes::Vector{Vector{Any}}
info::UnionSplitInfo
valid_worlds::WorldRange
end
any_ambig(info::UnionSplitInfo) = any(any_ambig, info.split)
any_ambig(m::UnionSplitMethodMatches) = any_ambig(m.info)
fully_covering(info::UnionSplitInfo) = all(fully_covering, info.split)
fully_covering(m::UnionSplitMethodMatches) = fully_covering(m.info)

nmatches(info::MethodMatchInfo) = length(info.results)
function nmatches(info::UnionSplitInfo)
n = 0
for mminfo in info.split
n += nmatches(mminfo)
end
return n
end

# intermediate state for computing gfresult
mutable struct CallInferenceState
inferidx::Int
rettype
exctype
all_effects::Effects
const_results::Union{Nothing,Vector{Union{Nothing,ConstResult}}} # keeps the results of inference with the extended lattice elements (if happened)
conditionals::Union{Nothing,Tuple{Vector{Any},Vector{Any}}} # keeps refinement information of call argument types when the return type is boolean
slotrefinements::Union{Nothing,Vector{Any}} # keeps refinement information on slot types obtained from call signature

# some additional fields for untyped objects (just to avoid capturing)
func
matches::Union{MethodMatches,UnionSplitMethodMatches}
function CallInferenceState(@nospecialize(func), matches::Union{MethodMatches,UnionSplitMethodMatches})
return new(#=inferidx=#1, #=rettype=#Bottom, #=exctype=#Bottom, #=all_effects=#EFFECTS_TOTAL,
#=const_results=#nothing, #=conditionals=#nothing, #=slotrefinements=#nothing,
func, matches)
end
end

function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(func),
arginfo::ArgInfo, si::StmtInfo, @nospecialize(atype),
sv::AbsIntState, max_methods::Int)
𝕃ₚ, 𝕃ᵢ = ipo_lattice(interp), typeinf_lattice(interp)
Expand All @@ -50,12 +119,12 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
return Future(CallMeta(Any, Any, Effects(), NoCallInfo()))
end

(; valid_worlds, applicable, info) = matches
(; valid_worlds, applicable) = matches
update_valid_age!(sv, valid_worlds) # need to record the negative world now, since even if we don't generate any useful information, inlining might want to add an invoke edge and it won't have this information anymore
if bail_out_toplevel_call(interp, sv)
napplicable = length(applicable)
local napplicable = length(applicable)
for i = 1:napplicable
sig = applicable[i].match.spec_types
local sig = applicable[i].match.spec_types
if !isdispatchtuple(sig)
# only infer fully concrete call sites in top-level expressions (ignoring even isa_compileable_sig matches)
add_remark!(interp, sv, "Refusing to infer non-concrete call site in top-level expression")
Expand All @@ -66,26 +135,17 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),

# final result
gfresult = Future{CallMeta}()
# intermediate work for computing gfresult
rettype = exctype = Bottom
conditionals = nothing # keeps refinement information of call argument types when the return type is boolean
const_results = nothing # or const_results::Vector{Union{Nothing,ConstResult}} if any const results are available
fargs = arginfo.fargs
all_effects = EFFECTS_TOTAL
slotrefinements = nothing # keeps refinement information on slot types obtained from call signature
state = CallInferenceState(func, matches)

# split the for loop off into a function, so that we can pause and restart it at will
i::Int = 1
f = Core.Box(f)
atype = Core.Box(atype)
function infercalls(interp, sv)
local napplicable = length(applicable)
local multiple_matches = napplicable > 1
while i <= napplicable
(; match, edges, edge_idx) = applicable[i]
method = match.method
sig = match.spec_types
if bail_out_call(interp, InferenceLoopState(rettype, all_effects), sv)
while state.inferidx <= napplicable
(; match, edges, edge_idx) = applicable[state.inferidx]
local method = match.method
local sig = match.spec_types
if bail_out_call(interp, InferenceLoopState(state.rettype, state.all_effects), sv)
add_remark!(interp, sv, "Call inference reached maximally imprecise information: bailing on doing more abstract inference.")
break
end
Expand All @@ -108,10 +168,11 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
this_exct = exct
# try constant propagation with argtypes for this match
# this is in preparation for inlining, or improving the return result
this_argtypes = isa(matches, MethodMatches) ? argtypes : matches.applicable_argtypes[i]
this_arginfo = ArgInfo(fargs, this_argtypes)
local matches = state.matches
this_argtypes = isa(matches, MethodMatches) ? argtypes : matches.applicable_argtypes[state.inferidx]
this_arginfo = ArgInfo(arginfo.fargs, this_argtypes)
const_call_result = abstract_call_method_with_const_args(interp,
mresult[], f.contents, this_arginfo, si, match, sv)
mresult[], state.func, this_arginfo, si, match, sv)
const_result = volatile_inf_result
if const_call_result !== nothing
this_const_conditional = ignorelimited(const_call_result.rt)
Expand Down Expand Up @@ -146,12 +207,13 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
end
end

all_effects = merge_effects(all_effects, effects)
state.all_effects = merge_effects(state.all_effects, effects)
if const_result !== nothing
local const_results = state.const_results
if const_results === nothing
const_results = fill!(Vector{Union{Nothing,ConstResult}}(undef, napplicable), nothing)
const_results = state.const_results = fill!(Vector{Union{Nothing,ConstResult}}(undef, napplicable), nothing)
end
const_results[i] = const_result
const_results[state.inferidx] = const_result
end
@assert !(this_conditional isa Conditional || this_rt isa MustAlias) "invalid lattice element returned from inter-procedural context"
if can_propagate_conditional(this_conditional, argtypes)
Expand All @@ -162,12 +224,14 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
this_rt = this_conditional
end

rettype = rettype ₚ this_rt
exctype = exctype ₚ this_exct
if has_conditional(𝕃ₚ, sv) && this_conditional !== Bottom && is_lattice_bool(𝕃ₚ, rettype) && fargs !== nothing
state.rettype = state.rettype ₚ this_rt
state.exctype = state.exctype ₚ this_exct
if has_conditional(𝕃ₚ, sv) && this_conditional !== Bottom && is_lattice_bool(𝕃ₚ, state.rettype) && arginfo.fargs !== nothing
local conditionals = state.conditionals
if conditionals === nothing
conditionals = Any[Bottom for _ in 1:length(argtypes)],
Any[Bottom for _ in 1:length(argtypes)]
conditionals = state.conditionals = (
Any[Bottom for _ in 1:length(argtypes)],
Any[Bottom for _ in 1:length(argtypes)])
end
for i = 1:length(argtypes)
cnd = conditional_argtype(𝕃ᵢ, this_conditional, match.spec_types, argtypes, i)
Expand All @@ -177,7 +241,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
end
edges[edge_idx] = edge

i += 1
state.inferidx += 1
return true
end # function handle1
if isready(mresult) && handle1(interp, sv)
Expand All @@ -188,30 +252,33 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
end
end # while

seenall = i > napplicable
seenall = state.inferidx > napplicable
retinfo = state.matches.info
if seenall # small optimization to skip some work that is already implied
local const_results = state.const_results
if const_results !== nothing
@assert napplicable == nmatches(info) == length(const_results)
info = ConstCallInfo(info, const_results)
@assert napplicable == nmatches(retinfo) == length(const_results)
retinfo = ConstCallInfo(retinfo, const_results)
end
if !fully_covering(matches) || any_ambig(matches)
if !fully_covering(state.matches) || any_ambig(state.matches)
# Account for the fact that we may encounter a MethodError with a non-covered or ambiguous signature.
all_effects = Effects(all_effects; nothrow=false)
exctype = exctype ₚ MethodError
state.all_effects = Effects(state.all_effects; nothrow=false)
state.exctype = state.exctype ₚ MethodError
end
local fargs = arginfo.fargs
if sv isa InferenceState && fargs !== nothing
slotrefinements = collect_slot_refinements(𝕃ᵢ, applicable, argtypes, fargs, sv)
state.slotrefinements = collect_slot_refinements(𝕃ᵢ, applicable, argtypes, fargs, sv)
end
rettype = from_interprocedural!(interp, rettype, sv, arginfo, conditionals)
if call_result_unused(si) && !(rettype === Bottom)
state.rettype = from_interprocedural!(interp, state.rettype, sv, arginfo, state.conditionals)
if call_result_unused(si) && !(state.rettype === Bottom)
add_remark!(interp, sv, "Call result type was widened because the return value is unused")
# We're mainly only here because the optimizer might want this code,
# but we ourselves locally don't typically care about it locally
# (beyond checking if it always throws).
# So avoid adding an edge, since we don't want to bother attempting
# to improve our result even if it does change (to always throw),
# and avoid keeping track of a more complex result type.
rettype = Any
state.rettype = Any
end
# if from_interprocedural added any pclimitations to the set inherited from the arguments,
# some of those may be part of our cycles, so those can be deleted now
Expand All @@ -230,23 +297,24 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
end
else
# there is unanalyzed candidate, widen type and effects to the top
rettype = exctype = Any
all_effects = Effects()
const_results = nothing
state.rettype = state.exctype = Any
state.all_effects = Effects()
state.const_results = nothing
end

# Also considering inferring the compilation signature for this method, so
# it is available to the compiler in case it ends up needing it for the invoke.
if isa(sv, InferenceState) && infer_compilation_signature(interp) && (!is_removable_if_unused(all_effects) || !call_result_unused(si))
i = 1
if (isa(sv, InferenceState) && infer_compilation_signature(interp) &&
(!is_removable_if_unused(state.all_effects) || !call_result_unused(si)))
inferidx = SafeBox{Int}(1)
function infercalls2(interp, sv)
local napplicable = length(applicable)
local multiple_matches = napplicable > 1
while i <= napplicable
(; match, edges, edge_idx) = applicable[i]
i += 1
method = match.method
sig = match.spec_types
while inferidx[] <= napplicable
(; match, edges, edge_idx) = applicable[inferidx[]]
inferidx[] += 1
local method = match.method
local sig = match.spec_types
mi = specialize_method(match; preexisting=true)
if mi === nothing || !const_prop_methodinstance_heuristic(interp, mi, arginfo, sv)
csig = get_compileable_sig(method, sig, match.sparams)
Expand All @@ -265,55 +333,14 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
infercalls2(interp, sv) || push!(sv.tasks, infercalls2)
end

gfresult[] = CallMeta(rettype, exctype, all_effects, info, slotrefinements)
gfresult[] = CallMeta(state.rettype, state.exctype, state.all_effects, retinfo, state.slotrefinements)
return true
end # function infercalls
# start making progress on the first call
infercalls(interp, sv) || push!(sv.tasks, infercalls)
return gfresult
end

struct FailedMethodMatch
reason::String
end

struct MethodMatchTarget
match::MethodMatch
edges::Vector{Union{Nothing,CodeInstance}}
edge_idx::Int
end

struct MethodMatches
applicable::Vector{MethodMatchTarget}
info::MethodMatchInfo
valid_worlds::WorldRange
end
any_ambig(result::MethodLookupResult) = result.ambig
any_ambig(info::MethodMatchInfo) = any_ambig(info.results)
any_ambig(m::MethodMatches) = any_ambig(m.info)
fully_covering(info::MethodMatchInfo) = info.fullmatch
fully_covering(m::MethodMatches) = fully_covering(m.info)

struct UnionSplitMethodMatches
applicable::Vector{MethodMatchTarget}
applicable_argtypes::Vector{Vector{Any}}
info::UnionSplitInfo
valid_worlds::WorldRange
end
any_ambig(info::UnionSplitInfo) = any(any_ambig, info.split)
any_ambig(m::UnionSplitMethodMatches) = any_ambig(m.info)
fully_covering(info::UnionSplitInfo) = all(fully_covering, info.split)
fully_covering(m::UnionSplitMethodMatches) = fully_covering(m.info)

nmatches(info::MethodMatchInfo) = length(info.results)
function nmatches(info::UnionSplitInfo)
n = 0
for mminfo in info.split
n += nmatches(mminfo)
end
return n
end

function find_method_matches(interp::AbstractInterpreter, argtypes::Vector{Any}, @nospecialize(atype);
max_union_splitting::Int = InferenceParams(interp).max_union_splitting,
max_methods::Int = InferenceParams(interp).max_methods)
Expand Down

2 comments on commit af9e6e3

@nanosoldier
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Executing the daily package evaluation, I will reply here when finished:

@nanosoldier runtests(isdaily = true)

@nanosoldier
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The package evaluation job you requested has completed - possible new issues were detected.
The full report is available.

Please sign in to comment.