Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve build_function for CTarget() #112

Merged
merged 3 commits into from
Mar 12, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 51 additions & 2 deletions src/build_function.jl
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,7 @@ function vars_to_pairs(name,vs, symsdict)
end

get_varnumber(varop, vars::Vector) = findfirst(x->isequal(x,varop),vars)
get_varnumber(varop, var) = isequal(var,varop) ? 0 : nothing

function numbered_expr(O::Symbolic,args...;varordering = args[1],offset = 0,
lhsname=gensym("du"),rhsnames=[gensym("MTK") for i in 1:length(args)])
Expand All @@ -360,7 +361,7 @@ function numbered_expr(O::Symbolic,args...;varordering = args[1],offset = 0,
for j in 1:length(args)
i = get_varnumber(O,args[j])
if i !== nothing
return :($(rhsnames[j])[$(i+offset)])
return i==0 ? :($(rhsnames[j])) : :($(rhsnames[j])[$(i+offset)])
end
end
end
Expand All @@ -383,6 +384,53 @@ end
numbered_expr(c,args...;kwargs...) = c
numbered_expr(c::Num,args...;kwargs...) = error("Num found")


# Replace certain multiplication and power expressions so they form valid C code
# Extra factors of 1 are hopefully eliminated by the C compiler
function coperators(expr)
for e in expr.args
if e isa Expr
coperators(e)
end
end
# Introduce another factor 1 to prevent contraction of terms like "5 * t" to "5t" (not valid C code)
if expr.head==:call && expr.args[1]==:* && length(expr.args)==3 && isa(expr.args[2], Real) && isa(expr.args[3], Symbol)
push!(expr.args, 1)
# Power operator does not exist in C, replace by multiplication or "pow"
elseif expr.head==:call && expr.args[1]==:^
@assert length(expr.args)==3 "Don't know how to handle ^ operation with <> 2 arguments"
x = expr.args[2]
n = expr.args[3]
empty!(expr.args)
# Replace by multiplication/division if
# x is a symbol and n is a small integer
# x is a more complex expression and n is ±1
# n is exactly 0
if (isa(n,Integer) && ((isa(x, Symbol) && abs(n) <= 3) || abs(n) <= 1)) || n==0
if n >= 0
append!(expr.args, [:*, fill(x, n)...])
# fill up with factor 1 so this expr can still be a multiplication
while length(expr.args) < 3
push!(expr.args, 1)
end
else # inverse of the above
if n==-1
term = x
else
term = :( ($(x)) ^ ($(-n)))
coperators(term)
end
append!(expr.args, [:/, 1., term])
end
#... otherwise use "pow" function
else
append!(expr.args, [:pow, x, n])
end
end
expr
end


"""
Build function target: `CTarget`

Expand Down Expand Up @@ -487,14 +535,15 @@ function _build_function(target::CTarget, ex::AbstractArray, args...;
rhs = numbered_expr(value(ex[row, col]), args...;
lhsname = lhsname,
rhsnames = rhsnames,
offset = -1) |> string
offset = -1) |> coperators |> string # Filter through coperators to produce valid C code in more cases
push!(equations, string(lhs, " = ", rhs, ";"))
end
end

argstrs = join(vcat("double* $(lhsname)",[typeof(args[i])<:Array ? "double* $(rhsnames[i])" : "double $(rhsnames[i])" for i in 1:length(args)]),", ")

ccode = """
#include <math.h>
void $fname($(argstrs...)) {$([string("\n ", eqn) for eqn ∈ equations]...)\n}
"""

Expand Down
24 changes: 18 additions & 6 deletions test/build_targets.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ expr = [a*x - x*y,-3y + x*y]
lhsname=:internal_var___du,
rhsnames=[:internal_var___u,:internal_var___p,:t]) ==
"""
#include <math.h>
void diffeqf(double* internal_var___du, double* internal_var___u, double* internal_var___p, double t) {
internal_var___du[0] = internal_var___p[0] * internal_var___u[0] + -1 * internal_var___u[0] * internal_var___u[1];
internal_var___du[1] = internal_var___u[0] * internal_var___u[1] + -3 * internal_var___u[1];
Expand Down Expand Up @@ -65,19 +66,30 @@ let
cfunc = build_function(expression, variables; target = Symbolics.CTarget(), expression = Val{true})

# Generated function should be out[0] = in[0], out[1] = in[1], out[2] = in[2], etc.
@test cfunc == "void diffeqf(double* du, double* RHS1) {\n du[0] = RHS1[0];\n du[1] = RHS1[1];\n du[2] = RHS1[2];\n du[3] = RHS1[3];\n du[4] = RHS1[4];\n du[5] = RHS1[5];\n du[6] = RHS1[6];\n du[7] = RHS1[7];\n du[8] = RHS1[8];\n du[9] = RHS1[9];\n du[10] = RHS1[10];\n du[11] = RHS1[11];\n}\n"
@test cfunc == "#include <math.h>\nvoid diffeqf(double* du, double* RHS1) {\n du[0] = RHS1[0];\n du[1] = RHS1[1];\n du[2] = RHS1[2];\n du[3] = RHS1[3];\n du[4] = RHS1[4];\n du[5] = RHS1[5];\n du[6] = RHS1[6];\n du[7] = RHS1[7];\n du[8] = RHS1[8];\n du[9] = RHS1[9];\n du[10] = RHS1[10];\n du[11] = RHS1[11];\n}\n"
end

# Scalar CTarget test
let
@variables x y z
expression = x + y + z
cfunc = build_function(expression, [x], [y], [z]; target = Symbolics.CTarget(), expression = Val{true})
@variables x y t
expression = x + y + t
cfunc = build_function(expression, [x], [y], t; target = Symbolics.CTarget(), expression = Val{true})

# Generated function should be out[0] = in1[0] + in2[0] + in3[0]
@test cfunc == "void diffeqf(double* du, double* RHS1, double* RHS2, double* RHS3) {\n du[0] = RHS1[0] + RHS2[0] + RHS3[0];\n}\n"
# Generated function should be out[0] = in1[0] + in2[0] + in3
@test cfunc == "#include <math.h>\nvoid diffeqf(double* du, double* RHS1, double* RHS2, double RHS3) {\n du[0] = RHS3 + RHS1[0] + RHS2[0];\n}\n"
end

# Scalar CTarget test with scalar multiplication and powers
let
@variables x y a t
expression = x^2 + y^-1 + sin(a)^3.5 + 2t
cfunc = build_function(expression, [x, y], [a], t; target = Symbolics.CTarget(), expression = Val{true})

# Generated function should avoid scalar multiplication of the form "4t" (currently done by adding another "* 1") and other invalid C syntax
@test cfunc == "#include <math.h>\nvoid diffeqf(double* du, double* RHS1, double* RHS2, double RHS3) {\n du[0] = 2 * RHS3 * 1 + pow(RHS1[0], 2) + 1.0 / RHS1[1] + pow(sin(RHS2[0]), 3.5);\n}\n"
end


# Matrix StanTarget test
let
@variables x[1:4] y[1:4] z[1:4]
Expand Down