Skip to content

Commit

Permalink
Merge pull request #371 from willow-ahrens/wma/fix_staging
Browse files Browse the repository at this point in the history
Wma/fix staging
  • Loading branch information
willow-ahrens authored Jan 10, 2024
2 parents 3df65ad + 3111ea4 commit a4db46f
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 23 deletions.
1 change: 0 additions & 1 deletion src/Finch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,6 @@ include("tensors/combinators/swizzle.jl")
include("tensors/combinators/scale.jl")
include("tensors/combinators/product.jl")


include("traits.jl")

export fsparse, fsparse!, fsprand, fspzeros, ffindnz, fread, fwrite, countstored
Expand Down
12 changes: 8 additions & 4 deletions src/interface/index.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,10 @@ getindex_rep_def(lvl::RepeatData, idx::Drop) = SolidData(ElementData(lvl.default
getindex_rep_def(lvl::RepeatData, idx) = SolidData(ElementData(lvl.default, lvl.eltype))
getindex_rep_def(lvl::RepeatData, idx::Type{<:AbstractUnitRange}) = SolidData(ElementData(lvl.default, lvl.eltype))

Base.getindex(arr::Tensor, inds...) = getindex_helper(arr, to_indices(arr, inds)...)
@staged function getindex_helper(arr, inds...)
Base.getindex(arr::Tensor, inds...) = getindex_helper(arr, to_indices(arr, inds))
@staged function getindex_helper(arr, inds)
inds <: Type{<:Tuple}
inds = inds.parameters
@assert ndims(arr) == length(inds)
N = ndims(arr)

Expand Down Expand Up @@ -69,8 +71,10 @@ Base.getindex(arr::Tensor, inds...) = getindex_helper(arr, to_indices(arr, inds)
end
end

Base.setindex!(arr::Tensor, src, inds...) = setindex_helper(arr, src, to_indices(arr, inds)...)
@staged function setindex_helper(arr, src, inds...)
Base.setindex!(arr::Tensor, src, inds...) = setindex_helper(arr, src, to_indices(arr, inds))
@staged function setindex_helper(arr, src, inds)
inds <: Type{<:Tuple}
inds = inds.parameters
@assert ndims(arr) == length(inds)
@assert sum(ndims.(inds)) == 0 || (ndims(src) == sum(ndims.(inds)))
N = ndims(arr)
Expand Down
40 changes: 22 additions & 18 deletions src/util/util.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,40 +37,44 @@ ensures the first Finch invocation runs in the latest world, and leaves hooks so
that subsequent calls to [`Finch.refresh`](@ref) can update the world and
invalidate old versions. If the body contains closures, this macro uses an
eval and invokelatest strategy. Otherwise, it uses a generated function.
This macro does not support type parameters, varargs, or keyword arguments.
"""
macro staged(def)
(@capture def :function(:call(~name, ~args...), ~body)) || throw(ArgumentError("unrecognized function definition in @staged"))

called = gensym(Symbol(name, :_called))
name_2 = gensym(Symbol(name, :_eval_invokelatest))
name_3 = gensym(Symbol(name, :_evaled))
name_generator = gensym(Symbol(name, :_generator))
name_invokelatest = gensym(Symbol(name, :_invokelatest))
name_eval_invokelatest = gensym(Symbol(name, :_eval_invokelatest))

def = quote
$called = false
function $name_generator($(args...))
$body
end

function $name_invokelatest($(args...))
$(Base.invokelatest)($name_eval_invokelatest, $(args...))
end

function $name_2($(args...))
global $called
if !$called
code = let ($(args...),) = ($(map((arg)->:(typeof($arg)), args)...),)
$body
function $name_eval_invokelatest($(args...))
code = $name_generator($(map((arg)->:(typeof($arg)), args)...),)
def = quote
function $($(QuoteNode(name_invokelatest)))($($(map(arg -> :(:($($(QuoteNode(arg)))::$(typeof($arg)))), args)...)))
$($(QuoteNode(name_eval_invokelatest)))($($(map(QuoteNode, args)...)))
end
def = quote
function $($(QuoteNode(name_3)))($($(map(QuoteNode, args)...)))
$code
end
function $($(QuoteNode(name_eval_invokelatest)))($($(map(arg -> :(:($($(QuoteNode(arg)))::$(typeof($arg)))), args)...)))
$code
end
($@__MODULE__).eval(def)
$called = true
end
Base.invokelatest(($@__MODULE__).$name_3, $(args...))
($@__MODULE__).eval(def)
$(Base.invokelatest)(($@__MODULE__).$name_eval_invokelatest, $(args...))
end

@generated function $name($(args...))
# Taken from https://github.com/NHDaly/StagedFunctions.jl/blob/6fafbc560421f70b05e3df330b872877db0bf3ff/src/StagedFunctions.jl#L116
body_2 = () -> begin
code = $(body)
code = $name_generator($(args...))
if has_function_def(macroexpand($@__MODULE__, code))
:($($(name_2))($($args...)))
:($($(name_invokelatest))($($(map(QuoteNode, args)...))))
else
quote
$code
Expand Down

0 comments on commit a4db46f

Please sign in to comment.