From 4c0658464990ae2c6edcc512544c4656737598d1 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 7 Jan 2025 14:41:45 +0530 Subject: [PATCH 1/5] fix: fix `fast_substitute` folding array of symbolics --- src/variable.jl | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/src/variable.jl b/src/variable.jl index 1760724ed..a4ccc8e7e 100644 --- a/src/variable.jl +++ b/src/variable.jl @@ -606,7 +606,7 @@ function fast_substitute(expr, subs; operator = Nothing) args = let canfold = canfold map(args) do x x′ = fast_substitute(x, subs; operator) - canfold[] = canfold[] && !(x′ isa Symbolic) + canfold[] = canfold[] && (symbolic_type(x′) == NotSymbolic() && !is_array_of_symbolics(x′)) x′ end end @@ -633,7 +633,7 @@ function fast_substitute(expr, pair::Pair; operator = Nothing) args = let canfold = canfold map(args) do x x′ = fast_substitute(x, pair; operator) - canfold[] = canfold[] && !(x′ isa Symbolic) + canfold[] = canfold[] && (symbolic_type(x′) == NotSymbolic() && !is_array_of_symbolics(x′)) x′ end end @@ -645,6 +645,13 @@ function fast_substitute(expr, pair::Pair; operator = Nothing) metadata(expr)) end +function is_array_of_symbolics(x) + symbolic_type(x) == ArraySymbolic() && return true + symbolic_type(x) == ScalarSymbolic() && return false + x isa AbstractArray && + any(y -> symbolic_type(y) != NotSymbolic() || is_array_of_symbolics(y), x) +end + function getparent(x, val=_fail) maybe_parent = getmetadata(x, Symbolics.GetindexParent, nothing) if maybe_parent !== nothing From c00f09f48252bc37c246ffcdd8536c30e80d741e Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 9 Jan 2025 14:12:38 +0530 Subject: [PATCH 2/5] test: add tests for `fast_subsitute` on function of array of symbolics --- test/utils.jl | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/test/utils.jl b/test/utils.jl index b47c1146a..8203b41f6 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -153,4 +153,12 @@ end test_nested_derivative = Dx(Dt(Dt(u))) result = diff2term(Symbolics.value(test_nested_derivative)) @test typeof(result) === Symbolics.BasicSymbolic{Real} -end \ No newline at end of file +end + +@testset "`fast_substitute` inside array symbolics" begin + @variables x y z + @register_symbolic foo(a::AbstractArray, b) + ex = foo([x, y], z) + ex2 = Symbolics.fixpoint_sub(ex, Dict(y => 1.0, z => 2.0)) + @test isequal(ex2, foo([x, 1.0], 2.0)) +end From 1ea3df51c35a92ed069f35df1e167b5548ef0d6f Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 9 Jan 2025 20:46:26 +0530 Subject: [PATCH 3/5] test: fix bruss test --- test/arrays.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/arrays.jl b/test/arrays.jl index 11ce88441..119f7aec7 100644 --- a/test/arrays.jl +++ b/test/arrays.jl @@ -388,7 +388,7 @@ end lapu = wrap(lapu) lapv = wrap(lapv) - f, g = build_function(dtu, u, v, t, expression=Val{false}) + f, g = build_function(dtu, u, v, t, expression=Val{false}, nanmath = false) du = zeros(Num, 8, 8) f(du, u,v,t) @test isequal(collect(du), collect(dtu)) From 689d673be115f3f981774659033c0931d53007d6 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 15 Jan 2025 14:41:09 +0530 Subject: [PATCH 4/5] build: bump SymbolicUtils compat --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 12ecc26cf..2db101d03 100644 --- a/Project.toml +++ b/Project.toml @@ -92,7 +92,7 @@ StaticArraysCore = "1.4" SymPy = "2.2" SymbolicIndexingInterface = "0.3.14" SymbolicLimits = "0.2.2" -SymbolicUtils = "3.7" +SymbolicUtils = "3.10" TermInterface = "2" julia = "1.10" From 9271d62dd304b4957901792c12b6c63a6606e79a Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 15 Jan 2025 14:41:28 +0530 Subject: [PATCH 5/5] fix: unwrap arguments in `_filter_poly` --- src/solver/preprocess.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/solver/preprocess.jl b/src/solver/preprocess.jl index 7d6a21c6e..f506b434d 100644 --- a/src/solver/preprocess.jl +++ b/src/solver/preprocess.jl @@ -135,7 +135,7 @@ function _filter_poly(expr, var) subs[i_var] = im expr = unwrap(expr1 + i_var * expr2) - args = arguments(expr) + args = map(unwrap, arguments(expr)) oper = operation(expr) return subs, term(oper, args...) end @@ -208,7 +208,7 @@ function _filter_poly(expr, var) end end - args = arguments(expr) + args = map(unwrap, arguments(expr)) oper = operation(expr) expr = term(oper, args...) return subs, expr