Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

empty do chain syntax #46

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 63 additions & 12 deletions src/Chain.jl
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,16 @@ function rewrite(expr, replacement)
(new_expr, replacement)
end

is_empty_do(x) = (x isa Expr &&
x.head == :do &&
length(x.args) == 2 &&
x.args[2] isa Expr &&
x.args[2].head == :-> &&
x.args[2].args[1].head == :tuple &&
length(x.args[2].args[1].args) == 0)



rewrite(l::LineNumberNode, replacement) = (l, replacement)

function rewrite_chain_block(firstpart, block)
Expand Down Expand Up @@ -144,13 +154,15 @@ macro chain(initial_value, block::Expr)
if !(block.head == :block)
block = Expr(:block, block)
end
rewrite_chain_block(initial_value, block)
esc(rewrite_chain_block(initial_value, block))
end



"""
@chain(initial_value, args...)

Rewrites a series of argments, either expressions or symbols, to feed the result
Rewrites a series of arguments, either expressions or symbols, to feed the result
of each line into the next one. The initial value is given by the first argument.

In all arguments, underscores are replaced by the argument's result.
Expand All @@ -168,7 +180,7 @@ x == sum(sqrt.(filter(!=(2), [1, 2, 3])))
```
"""
macro chain(initial_value, args...)
rewrite_chain_block(initial_value, Expr(:block, args...))
esc(rewrite_chain_block(initial_value, Expr(:block, args...)))
end

function rewrite_chain_block(block)
Expand Down Expand Up @@ -199,8 +211,6 @@ function rewrite_chain_block(block)
end

result = Expr(:let, Expr(:block), Expr(:block, rewritten_exprs..., replacement))

:($(esc(result)))
end

# if a line in a chain is a string, it can be parsed as a docstring
Expand Down Expand Up @@ -248,24 +258,65 @@ x == sum(sqrt.(filter(!=(2), [1, 2, 3])))
```
"""
macro chain(block::Expr)
rewrite_chain_block(block)
if is_empty_do(block)
esc(rewrite_do_block_chain(block))
else
esc(rewrite_chain_block(block))
end
end

function rewrite_do_block_chain(doblock; escape = true)
call = doblock.args[1]
if !(call isa Expr && call.head == :call)
error("Expected a call expression in do syntax")
end
var = gensym()
block = doblock.args[2].args[2]
if !(block isa Expr && block.head == :block)
error("Expected a block in do syntax")
end
chain_block_rewritten = rewrite_chain_block(var, block)
# dump(chain_block_rewritten)

Expr(
:do,
call,
Expr(
:->,
Expr(:tuple, var),
chain_block_rewritten
)
)
end


function replace_underscores(expr::Expr, replacement)
found_underscore = false

# if a @chain macrocall is found, only its first arg can be replaced if it's an
# underscore, otherwise the macro insides are left untouched
if expr.head == :macrocall && expr.args[1] == Symbol("@chain")
length(expr.args) < 3 && error("Malformed nested @chain macro")
expr.args[2] isa LineNumberNode || error("Malformed nested @chain macro")
arg3 = if expr.args[3] == Symbol("_")
expr.args[2] isa LineNumberNode || error("Expected LineNumberNode as arg 2 in nested @chain macro")
length(expr.args) < 3 && error("Nested @chain macro has no arguments.")
if length(expr.args) == 3 && is_empty_do(expr.args[3])
doblock = expr.args[3]
call = doblock.args[1]
had_underscore, new_call = replace_underscores(call, replacement)
if !had_underscore
new_call = insert_first_arg(call, replacement)
end
doblock.args[1] = new_call
newexpr = rewrite_do_block_chain(doblock, escape = false)
found_underscore = true
replacement
else
expr.args[3]
arg3 = if expr.args[3] == Symbol("_")
found_underscore = true
replacement
else
expr.args[3]
end
newexpr = Expr(:macrocall, Symbol("@chain"), expr.args[2], arg3, expr.args[4:end]...)
end
newexpr = Expr(:macrocall, Symbol("@chain"), expr.args[2], arg3, expr.args[4:end]...)
# for all other expressions, their arguments are checked for underscores recursively
# and replaced if any are found
else
Expand Down
48 changes: 48 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -441,4 +441,52 @@ end
@test 36 == @chain 1:3 begin
@chain _ sum _ ^ 2
end
end

@testset "empty do syntax" begin

@test [4, 6, 8] == @chain map(1:3) do
_ + 1
_ * 2
end

x = @chain 1:5 begin
@chain filter() do
_ + 1
isodd
end
@chain map() do
_ + 2
sqrt
end
end
@test x == sqrt.([2, 4] .+ 2)

@test [[["ax", "by"]], [["cz"]]] == @chain " ax by \n cz " begin
split("\n")
@chain map() do
split("|")
@chain map() do
strip
split
end
end
end

@test sum(sqrt.((1:10) .+ 1)) == @chain mapreduce(+, 1:10) do
_ + 1
sqrt
end

@test ["ho", "word"] == @chain begin
["hello", "world"]
@chain map() do
collect
@chain filter() do
uppercase
_ ∉ ('E', 'L')
end
String
end
end
end