From b668ceff6993ab8886a0128d5d527f89dd392f3f Mon Sep 17 00:00:00 2001 From: Raphael Chinchilla Date: Mon, 10 Jul 2023 16:46:16 -0700 Subject: [PATCH] Fixed minor bug in src/build_function where the keyword conv was not used --- src/build_function.jl | 38 ++++++++++++++++++++++++++++---------- 1 file changed, 28 insertions(+), 10 deletions(-) diff --git a/src/build_function.jl b/src/build_function.jl index 06a2c6e40..1c4606c55 100644 --- a/src/build_function.jl +++ b/src/build_function.jl @@ -113,11 +113,11 @@ function _build_function(target::JuliaTarget, op, args...; expr = if cse fun = Func(dargs, [], Code.cse(unwrap(op))) (wrap_code !== nothing) && (fun = wrap_code(fun)) - toexpr(fun, states) + conv(fun, states) else fun = Func(dargs, [], op) (wrap_code !== nothing) && (fun = wrap_code(fun)) - toexpr(fun, states) + conv(fun, states) end if expression == Val{true} @@ -142,14 +142,14 @@ function _build_function(target::JuliaTarget, op::Union{Arr, ArrayOp}, args...; Symbol("ˍ₋arg$(x[1])")), enumerate([args...])) expr = if cse - toexpr(Func(dargs, [], Code.cse(unwrap(op))), states) + conv(Func(dargs, [], Code.cse(unwrap(op))), states) else - toexpr(Func(dargs, [], op), states) + conv(Func(dargs, [], op), states) end outsym = Symbol("ˍ₋out") body = inplace_expr(unwrap(op), outsym) - oop_expr = toexpr(Func([outsym, dargs...], [], body), states) + oop_expr = conv(Func([outsym, dargs...], [], body), states) N = length(shape(op)) op = unwrap(op) @@ -161,7 +161,7 @@ function _build_function(target::JuliaTarget, op::Union{Arr, ArrayOp}, args...; $outsym end) |> LiteralExpr end - ip_expr = toexpr(Func(dargs, [], op_body), states) + ip_expr = conv(Func(dargs, [], op_body), states) if expression == Val{true} oop_expr, ip_expr else @@ -194,11 +194,28 @@ function fill_array_with_zero!(x::AbstractArray) end """ + _build_function(target::JuliaTarget, rhss::AbstractArray, args...; + conv=toexpr, + expression = Val{true}, + expression_module = @__MODULE__(), + checkbounds = false, + postprocess_fbody=ex -> ex, + linenumbers = false, + outputidxs=nothing, + skipzeros = false, + force_SA = false, + wrap_code = (nothing, nothing), + fillzeros = skipzeros && !(rhss isa SparseMatrixCSC), + states = LazyState(), + iip_config = (true, true), + parallel=nothing, cse = false, kwargs...) + Build function target: `JuliaTarget` ```julia function _build_function(target::JuliaTarget, rhss, args...; - conv = toexpr, expression = Val{true}, + conv = toexpr, + expression = Val{true}, checkbounds = false, linenumbers = false, headerfun = addheader, outputidxs=nothing, @@ -250,6 +267,7 @@ Special Keyword Arguments: safety with `skipzeros`. """ function _build_function(target::JuliaTarget, rhss::AbstractArray, args...; + conv=toexpr, expression = Val{true}, expression_module = @__MODULE__(), checkbounds = false, @@ -303,10 +321,10 @@ function _build_function(target::JuliaTarget, rhss::AbstractArray, args...; end if expression == Val{true} - return toexpr(oop_expr, states), toexpr(ip_expr, states) + return conv(oop_expr, states), conv(ip_expr, states) else - return _build_and_inject_function(expression_module, toexpr(oop_expr, states)), - _build_and_inject_function(expression_module, toexpr(ip_expr, states)) + return _build_and_inject_function(expression_module, conv(oop_expr, states)), + _build_and_inject_function(expression_module, conv(ip_expr, states)) end end