Skip to content

Commit

Permalink
Merge pull request #112 from philbit/buildfunction_c_patch
Browse files Browse the repository at this point in the history
Improve build_function for CTarget()
  • Loading branch information
ChrisRackauckas authored Mar 12, 2021
2 parents 751031a + d6dd933 commit 37f2166
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 8 deletions.
53 changes: 51 additions & 2 deletions src/build_function.jl
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,7 @@ function vars_to_pairs(name,vs, symsdict)
end

get_varnumber(varop, vars::Vector) = findfirst(x->isequal(x,varop),vars)
get_varnumber(varop, var) = isequal(var,varop) ? 0 : nothing

function numbered_expr(O::Symbolic,args...;varordering = args[1],offset = 0,
lhsname=gensym("du"),rhsnames=[gensym("MTK") for i in 1:length(args)])
Expand All @@ -360,7 +361,7 @@ function numbered_expr(O::Symbolic,args...;varordering = args[1],offset = 0,
for j in 1:length(args)
i = get_varnumber(O,args[j])
if i !== nothing
return :($(rhsnames[j])[$(i+offset)])
return i==0 ? :($(rhsnames[j])) : :($(rhsnames[j])[$(i+offset)])
end
end
end
Expand All @@ -383,6 +384,53 @@ end
numbered_expr(c,args...;kwargs...) = c
numbered_expr(c::Num,args...;kwargs...) = error("Num found")


# Replace certain multiplication and power expressions so they form valid C code
# Extra factors of 1 are hopefully eliminated by the C compiler
function coperators(expr)
for e in expr.args
if e isa Expr
coperators(e)
end
end
# Introduce another factor 1 to prevent contraction of terms like "5 * t" to "5t" (not valid C code)
if expr.head==:call && expr.args[1]==:* && length(expr.args)==3 && isa(expr.args[2], Real) && isa(expr.args[3], Symbol)
push!(expr.args, 1)
# Power operator does not exist in C, replace by multiplication or "pow"
elseif expr.head==:call && expr.args[1]==:^
@assert length(expr.args)==3 "Don't know how to handle ^ operation with <> 2 arguments"
x = expr.args[2]
n = expr.args[3]
empty!(expr.args)
# Replace by multiplication/division if
# x is a symbol and n is a small integer
# x is a more complex expression and n is ±1
# n is exactly 0
if (isa(n,Integer) && ((isa(x, Symbol) && abs(n) <= 3) || abs(n) <= 1)) || n==0
if n >= 0
append!(expr.args, [:*, fill(x, n)...])
# fill up with factor 1 so this expr can still be a multiplication
while length(expr.args) < 3
push!(expr.args, 1)
end
else # inverse of the above
if n==-1
term = x
else
term = :( ($(x)) ^ ($(-n)))
coperators(term)
end
append!(expr.args, [:/, 1., term])
end
#... otherwise use "pow" function
else
append!(expr.args, [:pow, x, n])
end
end
expr
end


"""
Build function target: `CTarget`
Expand Down Expand Up @@ -487,14 +535,15 @@ function _build_function(target::CTarget, ex::AbstractArray, args...;
rhs = numbered_expr(value(ex[row, col]), args...;
lhsname = lhsname,
rhsnames = rhsnames,
offset = -1) |> string
offset = -1) |> coperators |> string # Filter through coperators to produce valid C code in more cases
push!(equations, string(lhs, " = ", rhs, ";"))
end
end

argstrs = join(vcat("double* $(lhsname)",[typeof(args[i])<:Array ? "double* $(rhsnames[i])" : "double $(rhsnames[i])" for i in 1:length(args)]),", ")

ccode = """
#include <math.h>
void $fname($(argstrs...)) {$([string("\n ", eqn) for eqn equations]...)\n}
"""

Expand Down
24 changes: 18 additions & 6 deletions test/build_targets.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ expr = [a*x - x*y,-3y + x*y]
lhsname=:internal_var___du,
rhsnames=[:internal_var___u,:internal_var___p,:t]) ==
"""
#include <math.h>
void diffeqf(double* internal_var___du, double* internal_var___u, double* internal_var___p, double t) {
internal_var___du[0] = internal_var___p[0] * internal_var___u[0] + -1 * internal_var___u[0] * internal_var___u[1];
internal_var___du[1] = internal_var___u[0] * internal_var___u[1] + -3 * internal_var___u[1];
Expand Down Expand Up @@ -65,19 +66,30 @@ let
cfunc = build_function(expression, variables; target = Symbolics.CTarget(), expression = Val{true})

# Generated function should be out[0] = in[0], out[1] = in[1], out[2] = in[2], etc.
@test cfunc == "void diffeqf(double* du, double* RHS1) {\n du[0] = RHS1[0];\n du[1] = RHS1[1];\n du[2] = RHS1[2];\n du[3] = RHS1[3];\n du[4] = RHS1[4];\n du[5] = RHS1[5];\n du[6] = RHS1[6];\n du[7] = RHS1[7];\n du[8] = RHS1[8];\n du[9] = RHS1[9];\n du[10] = RHS1[10];\n du[11] = RHS1[11];\n}\n"
@test cfunc == "#include <math.h>\nvoid diffeqf(double* du, double* RHS1) {\n du[0] = RHS1[0];\n du[1] = RHS1[1];\n du[2] = RHS1[2];\n du[3] = RHS1[3];\n du[4] = RHS1[4];\n du[5] = RHS1[5];\n du[6] = RHS1[6];\n du[7] = RHS1[7];\n du[8] = RHS1[8];\n du[9] = RHS1[9];\n du[10] = RHS1[10];\n du[11] = RHS1[11];\n}\n"
end

# Scalar CTarget test
let
@variables x y z
expression = x + y + z
cfunc = build_function(expression, [x], [y], [z]; target = Symbolics.CTarget(), expression = Val{true})
@variables x y t
expression = x + y + t
cfunc = build_function(expression, [x], [y], t; target = Symbolics.CTarget(), expression = Val{true})

# Generated function should be out[0] = in1[0] + in2[0] + in3[0]
@test cfunc == "void diffeqf(double* du, double* RHS1, double* RHS2, double* RHS3) {\n du[0] = RHS1[0] + RHS2[0] + RHS3[0];\n}\n"
# Generated function should be out[0] = in1[0] + in2[0] + in3
@test cfunc == "#include <math.h>\nvoid diffeqf(double* du, double* RHS1, double* RHS2, double RHS3) {\n du[0] = RHS3 + RHS1[0] + RHS2[0];\n}\n"
end

# Scalar CTarget test with scalar multiplication and powers
let
@variables x y a t
expression = x^2 + y^-1 + sin(a)^3.5 + 2t
cfunc = build_function(expression, [x, y], [a], t; target = Symbolics.CTarget(), expression = Val{true})

# Generated function should avoid scalar multiplication of the form "4t" (currently done by adding another "* 1") and other invalid C syntax
@test cfunc == "#include <math.h>\nvoid diffeqf(double* du, double* RHS1, double* RHS2, double RHS3) {\n du[0] = 2 * RHS3 * 1 + pow(RHS1[0], 2) + 1.0 / RHS1[1] + pow(sin(RHS2[0]), 3.5);\n}\n"
end


# Matrix StanTarget test
let
@variables x[1:4] y[1:4] z[1:4]
Expand Down

0 comments on commit 37f2166

Please sign in to comment.