diff --git a/src/variable.jl b/src/variable.jl index 1760724ed..a4ccc8e7e 100644 --- a/src/variable.jl +++ b/src/variable.jl @@ -606,7 +606,7 @@ function fast_substitute(expr, subs; operator = Nothing) args = let canfold = canfold map(args) do x x′ = fast_substitute(x, subs; operator) - canfold[] = canfold[] && !(x′ isa Symbolic) + canfold[] = canfold[] && (symbolic_type(x′) == NotSymbolic() && !is_array_of_symbolics(x′)) x′ end end @@ -633,7 +633,7 @@ function fast_substitute(expr, pair::Pair; operator = Nothing) args = let canfold = canfold map(args) do x x′ = fast_substitute(x, pair; operator) - canfold[] = canfold[] && !(x′ isa Symbolic) + canfold[] = canfold[] && (symbolic_type(x′) == NotSymbolic() && !is_array_of_symbolics(x′)) x′ end end @@ -645,6 +645,13 @@ function fast_substitute(expr, pair::Pair; operator = Nothing) metadata(expr)) end +function is_array_of_symbolics(x) + symbolic_type(x) == ArraySymbolic() && return true + symbolic_type(x) == ScalarSymbolic() && return false + x isa AbstractArray && + any(y -> symbolic_type(y) != NotSymbolic() || is_array_of_symbolics(y), x) +end + function getparent(x, val=_fail) maybe_parent = getmetadata(x, Symbolics.GetindexParent, nothing) if maybe_parent !== nothing diff --git a/test/arrays.jl b/test/arrays.jl index 11ce88441..119f7aec7 100644 --- a/test/arrays.jl +++ b/test/arrays.jl @@ -388,7 +388,7 @@ end lapu = wrap(lapu) lapv = wrap(lapv) - f, g = build_function(dtu, u, v, t, expression=Val{false}) + f, g = build_function(dtu, u, v, t, expression=Val{false}, nanmath = false) du = zeros(Num, 8, 8) f(du, u,v,t) @test isequal(collect(du), collect(dtu)) diff --git a/test/utils.jl b/test/utils.jl index b47c1146a..8203b41f6 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -153,4 +153,12 @@ end test_nested_derivative = Dx(Dt(Dt(u))) result = diff2term(Symbolics.value(test_nested_derivative)) @test typeof(result) === Symbolics.BasicSymbolic{Real} -end \ No newline at end of file +end + +@testset "`fast_substitute` inside array symbolics" begin + @variables x y z + @register_symbolic foo(a::AbstractArray, b) + ex = foo([x, y], z) + ex2 = Symbolics.fixpoint_sub(ex, Dict(y => 1.0, z => 2.0)) + @test isequal(ex2, foo([x, 1.0], 2.0)) +end