Skip to content

Commit

Permalink
Fixed minor bug in src/build_function where the keyword conv was not …
Browse files Browse the repository at this point in the history
…used
  • Loading branch information
raphaelchinchilla committed Jul 10, 2023
1 parent 2f2131d commit b668cef
Showing 1 changed file with 28 additions and 10 deletions.
38 changes: 28 additions & 10 deletions src/build_function.jl
Original file line number Diff line number Diff line change
Expand Up @@ -113,11 +113,11 @@ function _build_function(target::JuliaTarget, op, args...;
expr = if cse
fun = Func(dargs, [], Code.cse(unwrap(op)))
(wrap_code !== nothing) && (fun = wrap_code(fun))
toexpr(fun, states)
conv(fun, states)
else
fun = Func(dargs, [], op)
(wrap_code !== nothing) && (fun = wrap_code(fun))
toexpr(fun, states)
conv(fun, states)
end

if expression == Val{true}
Expand All @@ -142,14 +142,14 @@ function _build_function(target::JuliaTarget, op::Union{Arr, ArrayOp}, args...;
Symbol("ˍ₋arg$(x[1])")), enumerate([args...]))

expr = if cse
toexpr(Func(dargs, [], Code.cse(unwrap(op))), states)
conv(Func(dargs, [], Code.cse(unwrap(op))), states)
else
toexpr(Func(dargs, [], op), states)
conv(Func(dargs, [], op), states)
end

outsym = Symbol("ˍ₋out")
body = inplace_expr(unwrap(op), outsym)
oop_expr = toexpr(Func([outsym, dargs...], [], body), states)
oop_expr = conv(Func([outsym, dargs...], [], body), states)

N = length(shape(op))
op = unwrap(op)
Expand All @@ -161,7 +161,7 @@ function _build_function(target::JuliaTarget, op::Union{Arr, ArrayOp}, args...;
$outsym
end) |> LiteralExpr
end
ip_expr = toexpr(Func(dargs, [], op_body), states)
ip_expr = conv(Func(dargs, [], op_body), states)
if expression == Val{true}
oop_expr, ip_expr
else
Expand Down Expand Up @@ -194,11 +194,28 @@ function fill_array_with_zero!(x::AbstractArray)
end

"""
_build_function(target::JuliaTarget, rhss::AbstractArray, args...;
conv=toexpr,
expression = Val{true},
expression_module = @__MODULE__(),
checkbounds = false,
postprocess_fbody=ex -> ex,
linenumbers = false,
outputidxs=nothing,
skipzeros = false,
force_SA = false,
wrap_code = (nothing, nothing),
fillzeros = skipzeros && !(rhss isa SparseMatrixCSC),
states = LazyState(),
iip_config = (true, true),
parallel=nothing, cse = false, kwargs...)
Build function target: `JuliaTarget`
```julia
function _build_function(target::JuliaTarget, rhss, args...;
conv = toexpr, expression = Val{true},
conv = toexpr,
expression = Val{true},
checkbounds = false,
linenumbers = false,
headerfun = addheader, outputidxs=nothing,
Expand Down Expand Up @@ -250,6 +267,7 @@ Special Keyword Arguments:
safety with `skipzeros`.
"""
function _build_function(target::JuliaTarget, rhss::AbstractArray, args...;
conv=toexpr,
expression = Val{true},
expression_module = @__MODULE__(),
checkbounds = false,
Expand Down Expand Up @@ -303,10 +321,10 @@ function _build_function(target::JuliaTarget, rhss::AbstractArray, args...;
end

if expression == Val{true}
return toexpr(oop_expr, states), toexpr(ip_expr, states)
return conv(oop_expr, states), conv(ip_expr, states)
else
return _build_and_inject_function(expression_module, toexpr(oop_expr, states)),
_build_and_inject_function(expression_module, toexpr(ip_expr, states))
return _build_and_inject_function(expression_module, conv(oop_expr, states)),
_build_and_inject_function(expression_module, conv(ip_expr, states))
end
end

Expand Down

0 comments on commit b668cef

Please sign in to comment.