diff --git a/README.md b/README.md index b35adb2..e7b86b8 100644 --- a/README.md +++ b/README.md @@ -93,6 +93,31 @@ result = let end ``` +## Alternative one-argument syntax + +If your initial argument name is long and / or the chain's result is assigned to a long +variable, it can look cleaner if the initial value is moved into the chain. +Here is such a long expression: + +```julia +a_long_result_variable_name = @chain a_long_input_variable_name begin + do_something + do_something_else(parameter) + do_other_thing(parameter, _) +end +``` + +This is equivalent to the following expression: + +```julia +a_long_result_variable_name = @chain begin + a_long_input_variable_name + do_something + do_something_else(parameter) + do_other_thing(parameter, _) +end +``` + ## The `@aside` macro For debugging, it's often useful to look at values in the middle of a pipeline. diff --git a/src/Chain.jl b/src/Chain.jl index 6990b90..600e58d 100644 --- a/src/Chain.jl +++ b/src/Chain.jl @@ -72,8 +72,93 @@ function rewrite_chain_block(firstpart, block) :($(esc(result))) end -macro chain(firstpart, block) - rewrite_chain_block(firstpart, block) +""" + @chain(initial_value, block::Expr) + +Rewrites a block expression to feed the result of each line into the next one. +The initial value is given by the first argument. + +In all lines, underscores are replaced by the previous line's result. +If there are no underscores and the expression is a symbol, the symbol is rewritten +to a function call with the previous result as the only argument. +If there are no underscores and the expression is a function call or a macrocall, +the call has the previous result prepended as the first argument. + +Example: + +``` +x = @chain [1, 2, 3] begin + filter(!=(2), _) + sqrt.(_) + sum +end +x == sum(sqrt.(filter(!=(2), [1, 2, 3]))) +``` +""" +macro chain(initial_value, block::Expr) + rewrite_chain_block(initial_value, block) +end + +function rewrite_chain_block(block) + if !(block isa Expr && block.head == :block) + error("Only argument of single-argument @chain must be a begin / end block") + end + + block_expressions = block.args + isempty(block_expressions) && error("No expressions found in chain block.") + + # assign first line to first gensym variable + firstvar = gensym() + rewritten_exprs = [] + replacement = firstvar + + did_first = false + for expr in block_expressions + # could be and expression first or a LineNumberNode, so a bit convoluted + # we just do the firstvar transformation for the first non LineNumberNode + # we encounter + if !(did_first || expr isa LineNumberNode) + expr = Expr(Symbol("="), firstvar, expr) + did_first = true + push!(rewritten_exprs, expr) + continue + end + + rewritten, replacement = rewrite(expr, replacement) + push!(rewritten_exprs, rewritten) + end + + result = Expr(:let, Expr(:block), Expr(:block, rewritten_exprs...)) + + :($(esc(result))) +end + +""" + @chain(block::Expr) + +Rewrites a block expression to feed the result of each line into the next one. +The first line serves as the initial value and is not rewritten. + +In all other lines, underscores are replaced by the previous line's result. +If there are no underscores and the expression is a symbol, the symbol is rewritten +to a function call with the previous result as the only argument. +If there are no underscores and the expression is a function call or a macrocall, +the call has the previous result prepended as the first argument. + +Example: + +``` +x = @chain begin + [1, 2, 3] + filter(!=(2), _) + sqrt.(_) + sum +end +x == sum(sqrt.(filter(!=(2), [1, 2, 3]))) +``` +""" +macro chain(block::Expr) + rewrite_chain_block(block) end function replace_underscores(expr::Expr, replacement) diff --git a/test/runtests.jl b/test/runtests.jl index 21a159f..aa15d4a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -151,4 +151,61 @@ end @broadcastminus(2.5, _) end @test yyy == broadcast(-, 2.5, xxx) +end + +@testset "single arg version" begin + x = [1, 2, 3] + + xx = @chain begin + x + end + @test xx == x + + # this has a different internal structure (one LineNumberNode missing I think) + @test x == @chain begin + x + end + + @test sum(x) == @chain begin + x + sum + end + + y = @chain begin + x + sum + end + @test y == sum(x) + + z = @chain begin + x + @. sqrt + sum(_) + end + @test z == sum(sqrt.(x)) + + @test sum == @chain begin + sum + end +end + +@testset "invalid single arg versions" begin + # empty + @test_throws LoadError eval(quote + @chain begin + end + end) + + # rvalue _ errors + @test_throws ErrorException eval(quote + @chain begin + _ + end + end) + + @test_throws ErrorException eval(quote + @chain begin + sum(_) + end + end) end \ No newline at end of file