diff --git a/README.md b/README.md index ddf1c26..e7c470e 100644 --- a/README.md +++ b/README.md @@ -142,6 +142,15 @@ a_long_result_variable_name = @chain begin end ``` +## One-liner syntax + +You can also use `@chain` as a one-liner, where no begin-end block is necessary. +This works well for short sequences that are still easy to parse visually without being on separate lines. + +```julia +@chain 1:10 filter(isodd, _) sum sqrt +``` + ## The `@aside` macro For debugging, it's often useful to look at values in the middle of a pipeline. diff --git a/src/Chain.jl b/src/Chain.jl index 7cd6139..5972a50 100644 --- a/src/Chain.jl +++ b/src/Chain.jl @@ -13,9 +13,9 @@ function insertionerror(expr) error( """Can't insert a first argument into: $expr. - + First argument insertion works with expressions like these, where [Module.SubModule.] is optional: - + [Module.SubModule.]func [Module.SubModule.]func(args...) [Module.SubModule.]func(args...; kwargs...) @@ -71,7 +71,7 @@ function insert_first_arg(e::Expr, firstarg) # @macro(args...) --> @macro(firstarg, args...) elseif head == :macrocall && - (is_moduled_symbol(args[1]) || args[1] isa Symbol) && + (is_moduled_symbol(args[1]) || args[1] isa Symbol) && args[2] isa LineNumberNode if args[1] == Symbol("@__dot__") @@ -108,10 +108,6 @@ end rewrite(l::LineNumberNode, replacement) = (l, replacement) function rewrite_chain_block(firstpart, block) - if !(block isa Expr && block.head == :block) - error("Second argument of @chain must be a begin / end block") - end - block_expressions = block.args # empty chain returns firstpart @@ -156,14 +152,37 @@ x == sum(sqrt.(filter(!=(2), [1, 2, 3]))) ``` """ macro chain(initial_value, block::Expr) + if !(block.head == :block) + block = Expr(:block, block) + end rewrite_chain_block(initial_value, block) end -function rewrite_chain_block(block) - if !(block isa Expr && block.head == :block) - error("Only argument of single-argument @chain must be a begin / end block") - end +""" + @chain(initial_value, args...) + +Rewrites a series of argments, either expressions or symbols, to feed the result +of each line into the next one. The initial value is given by the first argument. + +In all arguments, underscores are replaced by the argument's result. +If there are no underscores and the argument is a symbol, the symbol is rewritten +to a function call with the previous result as the only argument. +If there are no underscores and the argument is a function call or a macrocall, +the call has the previous result prepended as the first argument. + +Example: +``` +x = @chain [1, 2, 3] filter(!=(2), _) sqrt.(_) sum + +x == sum(sqrt.(filter(!=(2), [1, 2, 3]))) +``` +""" +macro chain(initial_value, args...) + rewrite_chain_block(initial_value, Expr(:block, args...)) +end + +function rewrite_chain_block(block) block_expressions = block.args isempty(block_expressions) && error("No expressions found in chain block.") diff --git a/test/runtests.jl b/test/runtests.jl index 7f44162..7db3787 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -58,17 +58,51 @@ end @test y == sqrt(sum(x) + 3 - 7) end +@testset "no begin" begin + x = [1, 2, 3] + y = @chain x sum + @test y == 6 + + f() = 1 + y = @chain f() first + @test y == 1 + + y = @chain x sum(_) max(0, _) first + @test y == 6 + + y = @chain 1 (t -> t + 1)() + @test y == 2 + + y = @chain 1 (t -> t + 1)() first max(0, _) + @test y == 2 + + y = @chain 1 (==(2)) + @test y == false + + y = @chain 1 (==(2)) first (==(false)) + @test y == true + + y = @chain 1 (_ + 1) + @test y == 2 + + y = @chain 1 (_ + 1) first max(0, _) + @test y == 2 + + # the begin block will be different from the normal chain block here + # only the last statement matters + y = @chain x begin + _ .+ 1 + _ .+ 2 + end sum + @test y == sum(x .+ 2) +end + @testset "invalid invocations" begin # just one argument @test_throws LoadError eval(quote @chain [1, 2, 3] end) - # no begin block - @test_throws LoadError eval(quote - @chain [1, 2, 3] sum - end) - # let block @test_throws LoadError eval(quote @chain [1, 2, 3] let @@ -232,7 +266,7 @@ end # issue 13 @testset "broadcasting calls" begin - + xs = [1, 2, 3] ys = @chain xs begin sin.() @@ -275,10 +309,10 @@ module LocalModule macro sin(exp) :(sin($(esc(exp)))) end - + macro broadcastminus(exp1, exp2) :(broadcast(-, $(esc(exp1)), $(esc(exp2)))) - end + end module SubModule function square(xs) @@ -294,10 +328,10 @@ module LocalModule macro sin(exp) :(sin($(esc(exp)))) end - + macro broadcastminus(exp1, exp2) :(broadcast(-, $(esc(exp1)), $(esc(exp2)))) - end + end end end