diff --git a/src/Chain.jl b/src/Chain.jl index 600e58d..e9dd9a0 100644 --- a/src/Chain.jl +++ b/src/Chain.jl @@ -10,19 +10,25 @@ insert_first_arg(symbol::Symbol, firstarg) = Expr(:call, symbol, firstarg) insert_first_arg(any, firstarg) = error("Can't insert an argument to $any. Needs to be a Symbol or a call expression") function insert_first_arg(e::Expr, firstarg) + head = e.head + args = e.args # f(a, b) --> f(firstarg, a, b) - if e.head == :call && length(e.args) > 1 - Expr(e.head, e.args[1], firstarg, e.args[2:end]...) + if head == :call && length(args) > 1 + if length(args) ≥ 2 && Meta.isexpr(args[2], :parameters) + Expr(head, args[1:2]..., firstarg, args[3:end]...) + else + Expr(head, args[1], firstarg, args[2:end]...) + end # @. somesymbol --> somesymbol.(firstarg) - elseif e.head == :macrocall && length(e.args) == 3 && e.args[1] == Symbol("@__dot__") && - e.args[2] isa LineNumberNode && e.args[3] isa Symbol - Expr(:., e.args[3], Expr(:tuple, firstarg)) + elseif head == :macrocall && length(args) == 3 && args[1] == Symbol("@__dot__") && + args[2] isa LineNumberNode && args[3] isa Symbol + Expr(:., args[3], Expr(:tuple, firstarg)) # @macro(a, b) --> @macro(firstarg, a, b) - elseif e.head == :macrocall && e.args[1] isa Symbol && e.args[2] isa LineNumberNode - Expr(e.head, e.args[1], e.args[2], firstarg, e.args[3:end]...) + elseif head == :macrocall && args[1] isa Symbol && args[2] isa LineNumberNode + Expr(head, args[1], args[2], firstarg, args[3:end]...) else error("Can't prepend first arg to expression $e that isn't a call.") @@ -45,7 +51,7 @@ function rewrite(expr, replacement) replacement = gensym() new_expr = Expr(Symbol("="), replacement, new_expr) end - + (new_expr, replacement) end diff --git a/test/runtests.jl b/test/runtests.jl index aa15d4a..37c8fb4 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -133,7 +133,7 @@ macro broadcastminus(exp1, exp2) end @testset "splicing into macro calls" begin - + x = 1 y = @chain x begin @sin @@ -208,4 +208,16 @@ end sum(_) end end) -end \ No newline at end of file +end + +@testset "handling keyword argments" begin + f(a; kwarg) = (a, kwarg) + @test (:a, :kwarg) == @chain begin + :a + f(kwarg = :kwarg) + end + @test (:a, :kwarg) == @chain begin + :a + f(; kwarg = :kwarg) + end +end