diff --git a/src/Chain.jl b/src/Chain.jl index be144fa..d22259a 100644 --- a/src/Chain.jl +++ b/src/Chain.jl @@ -110,6 +110,16 @@ function rewrite(expr, replacement) (new_expr, replacement) end +is_empty_do(x) = (x isa Expr && + x.head == :do && + length(x.args) == 2 && + x.args[2] isa Expr && + x.args[2].head == :-> && + x.args[2].args[1].head == :tuple && + length(x.args[2].args[1].args) == 0) + + + rewrite(l::LineNumberNode, replacement) = (l, replacement) function rewrite_chain_block(firstpart, block) @@ -144,13 +154,15 @@ macro chain(initial_value, block::Expr) if !(block.head == :block) block = Expr(:block, block) end - rewrite_chain_block(initial_value, block) + esc(rewrite_chain_block(initial_value, block)) end + + """ @chain(initial_value, args...) -Rewrites a series of argments, either expressions or symbols, to feed the result +Rewrites a series of arguments, either expressions or symbols, to feed the result of each line into the next one. The initial value is given by the first argument. In all arguments, underscores are replaced by the argument's result. @@ -168,7 +180,7 @@ x == sum(sqrt.(filter(!=(2), [1, 2, 3]))) ``` """ macro chain(initial_value, args...) - rewrite_chain_block(initial_value, Expr(:block, args...)) + esc(rewrite_chain_block(initial_value, Expr(:block, args...))) end function rewrite_chain_block(block) @@ -199,8 +211,6 @@ function rewrite_chain_block(block) end result = Expr(:let, Expr(:block), Expr(:block, rewritten_exprs..., replacement)) - - :($(esc(result))) end # if a line in a chain is a string, it can be parsed as a docstring @@ -248,24 +258,65 @@ x == sum(sqrt.(filter(!=(2), [1, 2, 3]))) ``` """ macro chain(block::Expr) - rewrite_chain_block(block) + if is_empty_do(block) + esc(rewrite_do_block_chain(block)) + else + esc(rewrite_chain_block(block)) + end end +function rewrite_do_block_chain(doblock; escape = true) + call = doblock.args[1] + if !(call isa Expr && call.head == :call) + error("Expected a call expression in do syntax") + end + var = gensym() + block = doblock.args[2].args[2] + if !(block isa Expr && block.head == :block) + error("Expected a block in do syntax") + end + chain_block_rewritten = rewrite_chain_block(var, block) + # dump(chain_block_rewritten) + + Expr( + :do, + call, + Expr( + :->, + Expr(:tuple, var), + chain_block_rewritten + ) + ) +end + + function replace_underscores(expr::Expr, replacement) found_underscore = false # if a @chain macrocall is found, only its first arg can be replaced if it's an # underscore, otherwise the macro insides are left untouched if expr.head == :macrocall && expr.args[1] == Symbol("@chain") - length(expr.args) < 3 && error("Malformed nested @chain macro") - expr.args[2] isa LineNumberNode || error("Malformed nested @chain macro") - arg3 = if expr.args[3] == Symbol("_") + expr.args[2] isa LineNumberNode || error("Expected LineNumberNode as arg 2 in nested @chain macro") + length(expr.args) < 3 && error("Nested @chain macro has no arguments.") + if length(expr.args) == 3 && is_empty_do(expr.args[3]) + doblock = expr.args[3] + call = doblock.args[1] + had_underscore, new_call = replace_underscores(call, replacement) + if !had_underscore + new_call = insert_first_arg(call, replacement) + end + doblock.args[1] = new_call + newexpr = rewrite_do_block_chain(doblock, escape = false) found_underscore = true - replacement else - expr.args[3] + arg3 = if expr.args[3] == Symbol("_") + found_underscore = true + replacement + else + expr.args[3] + end + newexpr = Expr(:macrocall, Symbol("@chain"), expr.args[2], arg3, expr.args[4:end]...) end - newexpr = Expr(:macrocall, Symbol("@chain"), expr.args[2], arg3, expr.args[4:end]...) # for all other expressions, their arguments are checked for underscores recursively # and replaced if any are found else diff --git a/test/runtests.jl b/test/runtests.jl index a5a2542..a51583d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -441,4 +441,52 @@ end @test 36 == @chain 1:3 begin @chain _ sum _ ^ 2 end +end + +@testset "empty do syntax" begin + + @test [4, 6, 8] == @chain map(1:3) do + _ + 1 + _ * 2 + end + + x = @chain 1:5 begin + @chain filter() do + _ + 1 + isodd + end + @chain map() do + _ + 2 + sqrt + end + end + @test x == sqrt.([2, 4] .+ 2) + + @test [[["ax", "by"]], [["cz"]]] == @chain " ax by \n cz " begin + split("\n") + @chain map() do + split("|") + @chain map() do + strip + split + end + end + end + + @test sum(sqrt.((1:10) .+ 1)) == @chain mapreduce(+, 1:10) do + _ + 1 + sqrt + end + + @test ["ho", "word"] == @chain begin + ["hello", "world"] + @chain map() do + collect + @chain filter() do + uppercase + _ ∉ ('E', 'L') + end + String + end + end end \ No newline at end of file