From 0ea30bf4b89099be98c7a24638391b131e96775a Mon Sep 17 00:00:00 2001 From: Karl Wessel Date: Fri, 8 Nov 2024 01:14:54 +0100 Subject: [PATCH 1/3] add flag for activating robust calculation of expand_derivatives --- src/diff.jl | 38 +++++++++++++++++++------------------- test/diff.jl | 30 ++++++++++++++++++++++++++++++ 2 files changed, 49 insertions(+), 19 deletions(-) diff --git a/src/diff.jl b/src/diff.jl index 891e2cde2..f9da8c071 100644 --- a/src/diff.jl +++ b/src/diff.jl @@ -180,12 +180,12 @@ julia> dfx=expand_derivatives(Dx(f)) (k*((2abs(x - y)) / y - 2z)*IfElse.ifelse(signbit(x - y), -1, 1)) / y ``` """ -function expand_derivatives(O::Symbolic, simplify=false; occurrences=nothing) +function expand_derivatives(O::Symbolic, simplify=false; robust=false, occurrences=nothing) if iscall(O) && isa(operation(O), Differential) arg = only(arguments(O)) - arg = expand_derivatives(arg, false) + arg = expand_derivatives(arg, false; robust) - if occurrences == nothing + if robust || occurrences == nothing occurrences = occursin_info(operation(O).x, arg) end @@ -202,14 +202,14 @@ function expand_derivatives(O::Symbolic, simplify=false; occurrences=nothing) return D(arg) # base case if any argument is directly equal to the i.v. else return sum(inner_args, init=0) do a - return expand_derivatives(Differential(a)(arg)) * - expand_derivatives(D(a)) + return expand_derivatives(Differential(a)(arg); robust) * + expand_derivatives(D(a); robust) end end elseif op === (IfElse.ifelse) args = arguments(arg) O = op(args[1], D(args[2]), D(args[3])) - return expand_derivatives(O, simplify; occurrences) + return expand_derivatives(O, simplify; robust, occurrences) elseif isa(op, Differential) # The recursive expand_derivatives was not able to remove # a nested Differential. We can attempt to differentiate the @@ -218,12 +218,12 @@ function expand_derivatives(O::Symbolic, simplify=false; occurrences=nothing) if isequal(op.x, D.x) return D(arg) else - inner = expand_derivatives(D(arguments(arg)[1]), false) + inner = expand_derivatives(D(arguments(arg)[1]), false; robust) # if the inner expression is not expandable either, return if iscall(inner) && operation(inner) isa Differential return D(arg) else - return expand_derivatives(op(inner), simplify) + return expand_derivatives(op(inner), simplify; robust) end end elseif isa(op, Integral) @@ -231,7 +231,7 @@ function expand_derivatives(O::Symbolic, simplify=false; occurrences=nothing) domain = op.domain.domain a, b = DomainSets.endpoints(domain) c = 0 - inner_function = expand_derivatives(arguments(arg)[1]) + inner_function = expand_derivatives(arguments(arg)[1]; robust) if iscall(value(a)) t1 = SymbolicUtils.substitute(inner_function, Dict(op.domain.variables => value(a))) t2 = D(a) @@ -242,7 +242,7 @@ function expand_derivatives(O::Symbolic, simplify=false; occurrences=nothing) t2 = D(b) c += t1*t2 end - inner = expand_derivatives(D(arguments(arg)[1])) + inner = expand_derivatives(D(arguments(arg)[1]); robust) c += op(inner) return value(c) end @@ -254,7 +254,7 @@ function expand_derivatives(O::Symbolic, simplify=false; occurrences=nothing) c = 0 for i in 1:l - t2 = expand_derivatives(D(inner_args[i]),false, occurrences=arguments(occurrences)[i]) + t2 = expand_derivatives(D(inner_args[i]),false; robust, occurrences=arguments(occurrences)[i]) x = if _iszero(t2) t2 @@ -286,23 +286,23 @@ function expand_derivatives(O::Symbolic, simplify=false; occurrences=nothing) return simplify ? SymbolicUtils.simplify(x) : x end elseif iscall(O) && isa(operation(O), Integral) - return operation(O)(expand_derivatives(arguments(O)[1])) + return operation(O)(expand_derivatives(arguments(O)[1]; robust)) elseif !hasderiv(O) return O else - args = map(a->expand_derivatives(a, false), arguments(O)) + args = map(a->expand_derivatives(a, false; robust), arguments(O)) O1 = operation(O)(args...) return simplify ? SymbolicUtils.simplify(O1) : O1 end end -function expand_derivatives(n::Num, simplify=false; occurrences=nothing) - wrap(expand_derivatives(value(n), simplify; occurrences=occurrences)) +function expand_derivatives(n::Num, simplify=false; robust=false, occurrences=nothing) + wrap(expand_derivatives(value(n), simplify; robust, occurrences)) end -function expand_derivatives(n::Complex{Num}, simplify=false; occurrences=nothing) - wrap(ComplexTerm{Real}(expand_derivatives(real(n), simplify; occurrences=occurrences), - expand_derivatives(imag(n), simplify; occurrences=occurrences))) +function expand_derivatives(n::Complex{Num}, simplify=false; robust=false, occurrences=nothing) + wrap(ComplexTerm{Real}(expand_derivatives(real(n), simplify; robust, occurrences), + expand_derivatives(imag(n), simplify; robust, occurrences))) end -expand_derivatives(x, simplify=false; occurrences=nothing) = x +expand_derivatives(x, simplify=false; robust=false, occurrences=nothing) = x _iszero(x) = false _isone(x) = false diff --git a/test/diff.jl b/test/diff.jl index d40fa1185..76833b8bc 100644 --- a/test/diff.jl +++ b/test/diff.jl @@ -349,6 +349,36 @@ let @test isequal(expand_derivatives(Differential(t)(t^2 + im*t)), 2t + im) end +# 1262 +# +let + @variables t b(t) + D = Differential(t) + expr = b - ((D(b))^2) * D(D(b)) + expr2 = D(expr) + @test isequal(expand_derivatives(expr), expand_derivatives(expr; robust=true)) + @test_throws BoundsError expand_derivatives(expr2) + @test isequal(expand_derivatives(expr2; robust=true), D(b) - (D(b)^2)*D(D(D(b))) - 2D(b)*(D(D(b))^2)) +end + +# 1126 +# +let + @syms y f(y) g(y) h(y) + D = Differential(y) + + expr_gen = (fun) -> D(D(((-D(D(fun))) / g(y)))) + + expr = expr_gen(g(y)) + @test_broken isequal(expand_derivatives(expr), expand_derivatives(expr; robust=true)) + expr = expr_gen(h(y)) + @test_broken isequal(expand_derivatives(expr), expand_derivatives(expr; robust=true)) + + expected = substitute(expand_derivatives(expr; robust=true), h(y) => f(y)) + expr = expr_gen(f(y)) + @test_throws BoundsError expand_derivatives(expr) + @test isequal(expand(expand_derivatives(expr; robust=true)), expected) +end # Check `is_derivative` function let From 89ac04e9e01e8b49177a2a7a013c77445fc1631d Mon Sep 17 00:00:00 2001 From: Karl Wessel Date: Tue, 12 Nov 2024 00:51:51 +0100 Subject: [PATCH 2/3] make sure to expand differentials in subtrees only once --- src/diff.jl | 206 ++++++++++++++++++++++++++------------------------- test/diff.jl | 9 +-- 2 files changed, 109 insertions(+), 106 deletions(-) diff --git a/src/diff.jl b/src/diff.jl index f9da8c071..a2ec1924a 100644 --- a/src/diff.jl +++ b/src/diff.jl @@ -150,6 +150,109 @@ function recursive_hasoperator(op, O) end end +function executediff(D, arg, simplify=false; robust=false, occurrences=nothing) + if robust || occurrences == nothing + occurrences = occursin_info(D.x, arg) + end + + _isfalse(occurrences) && return 0 + occurrences isa Bool && return 1 # means it's a `true` + + if !iscall(arg) + return D(arg) # Cannot expand + elseif (op = operation(arg); issym(op)) + inner_args = arguments(arg) + if any(isequal(D.x), inner_args) + return D(arg) # base case if any argument is directly equal to the i.v. + else + return sum(inner_args, init=0) do a + return executediff(Differential(a), arg; robust) * + executediff(D, a; robust) + end + end + elseif op === (IfElse.ifelse) + args = arguments(arg) + O = op(args[1], + executediff(D, args[2], simplify; robust, occurrences=arguments(occurrences)[2]), + executediff(D, args[3], simplify; robust, occurrences=arguments(occurrences)[3])) + return O + elseif isa(op, Differential) + # The recursive expand_derivatives was not able to remove + # a nested Differential. We can attempt to differentiate the + # inner expression wrt to the outer iv. And leave the + # unexpandable Differential outside. + if isequal(op.x, D.x) + return D(arg) + else + inner = executediff(D, arguments(arg)[1], false; robust) + # if the inner expression is not expandable either, return + if iscall(inner) && operation(inner) isa Differential + return D(arg) + else + return expand_derivatives(op(inner), simplify; robust) # TODO + end + end + elseif isa(op, Integral) + if isa(op.domain.domain, AbstractInterval) + domain = op.domain.domain + a, b = DomainSets.endpoints(domain) + c = 0 + inner_function = expand_derivatives(arguments(arg)[1]; robust) # TODO + if iscall(value(a)) + t1 = SymbolicUtils.substitute(inner_function, Dict(op.domain.variables => value(a))) + t2 = D(a) + c -= t1*t2 + end + if iscall(value(b)) + t1 = SymbolicUtils.substitute(inner_function, Dict(op.domain.variables => value(b))) + t2 = D(b) + c += t1*t2 + end + inner = executediff(D, arguments(arg)[1]; robust) + c += op(inner) + return value(c) + end + end + + inner_args = arguments(arg) + l = length(inner_args) + exprs = [] + c = 0 + + for i in 1:l + t2 = executediff(D, inner_args[i],false; robust, occurrences=arguments(occurrences)[i]) + + x = if _iszero(t2) + t2 + elseif _isone(t2) + d = derivative_idx(arg, i) + d isa NoDeriv ? D(arg) : d + else + t1 = derivative_idx(arg, i) + t1 = t1 isa NoDeriv ? D(arg) : t1 + t1 * t2 + end + + if _iszero(x) + continue + elseif x isa Symbolic + push!(exprs, x) + else + c += x + end + end + + if isempty(exprs) + return c + elseif length(exprs) == 1 + term = (simplify ? SymbolicUtils.simplify(exprs[1]) : exprs[1]) + return _iszero(c) ? term : c + term + else + x = +((!_iszero(c) ? vcat(c, exprs) : exprs)...) + return simplify ? SymbolicUtils.simplify(x) : x + end +end + """ $(SIGNATURES) @@ -184,107 +287,6 @@ function expand_derivatives(O::Symbolic, simplify=false; robust=false, occurrenc if iscall(O) && isa(operation(O), Differential) arg = only(arguments(O)) arg = expand_derivatives(arg, false; robust) - - if robust || occurrences == nothing - occurrences = occursin_info(operation(O).x, arg) - end - - _isfalse(occurrences) && return 0 - occurrences isa Bool && return 1 # means it's a `true` - - D = operation(O) - - if !iscall(arg) - return D(arg) # Cannot expand - elseif (op = operation(arg); issym(op)) - inner_args = arguments(arg) - if any(isequal(D.x), inner_args) - return D(arg) # base case if any argument is directly equal to the i.v. - else - return sum(inner_args, init=0) do a - return expand_derivatives(Differential(a)(arg); robust) * - expand_derivatives(D(a); robust) - end - end - elseif op === (IfElse.ifelse) - args = arguments(arg) - O = op(args[1], D(args[2]), D(args[3])) - return expand_derivatives(O, simplify; robust, occurrences) - elseif isa(op, Differential) - # The recursive expand_derivatives was not able to remove - # a nested Differential. We can attempt to differentiate the - # inner expression wrt to the outer iv. And leave the - # unexpandable Differential outside. - if isequal(op.x, D.x) - return D(arg) - else - inner = expand_derivatives(D(arguments(arg)[1]), false; robust) - # if the inner expression is not expandable either, return - if iscall(inner) && operation(inner) isa Differential - return D(arg) - else - return expand_derivatives(op(inner), simplify; robust) - end - end - elseif isa(op, Integral) - if isa(op.domain.domain, AbstractInterval) - domain = op.domain.domain - a, b = DomainSets.endpoints(domain) - c = 0 - inner_function = expand_derivatives(arguments(arg)[1]; robust) - if iscall(value(a)) - t1 = SymbolicUtils.substitute(inner_function, Dict(op.domain.variables => value(a))) - t2 = D(a) - c -= t1*t2 - end - if iscall(value(b)) - t1 = SymbolicUtils.substitute(inner_function, Dict(op.domain.variables => value(b))) - t2 = D(b) - c += t1*t2 - end - inner = expand_derivatives(D(arguments(arg)[1]); robust) - c += op(inner) - return value(c) - end - end - - inner_args = arguments(arg) - l = length(inner_args) - exprs = [] - c = 0 - - for i in 1:l - t2 = expand_derivatives(D(inner_args[i]),false; robust, occurrences=arguments(occurrences)[i]) - - x = if _iszero(t2) - t2 - elseif _isone(t2) - d = derivative_idx(arg, i) - d isa NoDeriv ? D(arg) : d - else - t1 = derivative_idx(arg, i) - t1 = t1 isa NoDeriv ? D(arg) : t1 - t1 * t2 - end - - if _iszero(x) - continue - elseif x isa Symbolic - push!(exprs, x) - else - c += x - end - end - - if isempty(exprs) - return c - elseif length(exprs) == 1 - term = (simplify ? SymbolicUtils.simplify(exprs[1]) : exprs[1]) - return _iszero(c) ? term : c + term - else - x = +((!_iszero(c) ? vcat(c, exprs) : exprs)...) - return simplify ? SymbolicUtils.simplify(x) : x - end elseif iscall(O) && isa(operation(O), Integral) return operation(O)(expand_derivatives(arguments(O)[1]; robust)) elseif !hasderiv(O) @@ -294,6 +296,8 @@ function expand_derivatives(O::Symbolic, simplify=false; robust=false, occurrenc O1 = operation(O)(args...) return simplify ? SymbolicUtils.simplify(O1) : O1 end + + executediff(operation(O), arg, simplify; robust, occurrences) end function expand_derivatives(n::Num, simplify=false; robust=false, occurrences=nothing) wrap(expand_derivatives(value(n), simplify; robust, occurrences)) diff --git a/test/diff.jl b/test/diff.jl index 76833b8bc..4a8484c1c 100644 --- a/test/diff.jl +++ b/test/diff.jl @@ -357,8 +357,8 @@ let expr = b - ((D(b))^2) * D(D(b)) expr2 = D(expr) @test isequal(expand_derivatives(expr), expand_derivatives(expr; robust=true)) - @test_throws BoundsError expand_derivatives(expr2) @test isequal(expand_derivatives(expr2; robust=true), D(b) - (D(b)^2)*D(D(D(b))) - 2D(b)*(D(D(b))^2)) + @test isequal(expand_derivatives(expr2; robust=true), expand_derivatives(expr2)) end # 1126 @@ -370,14 +370,13 @@ let expr_gen = (fun) -> D(D(((-D(D(fun))) / g(y)))) expr = expr_gen(g(y)) - @test_broken isequal(expand_derivatives(expr), expand_derivatives(expr; robust=true)) + @test isequal(expand_derivatives(expr), expand_derivatives(expr; robust=true)) expr = expr_gen(h(y)) - @test_broken isequal(expand_derivatives(expr), expand_derivatives(expr; robust=true)) + @test isequal(expand_derivatives(expr), expand_derivatives(expr; robust=true)) expected = substitute(expand_derivatives(expr; robust=true), h(y) => f(y)) expr = expr_gen(f(y)) - @test_throws BoundsError expand_derivatives(expr) - @test isequal(expand(expand_derivatives(expr; robust=true)), expected) + @test isequal(expand(expand_derivatives(expr)), expand(expand_derivatives(expr; robust=true))) end # Check `is_derivative` function From 3bf685a71d4f9e98e247cded435a66a134dd4203 Mon Sep 17 00:00:00 2001 From: Karl Wessel Date: Fri, 15 Nov 2024 01:58:22 +0100 Subject: [PATCH 3/3] remove robust flag --- src/diff.jl | 66 +++++++++++++++++++++++++++++++--------------------- test/diff.jl | 13 +++++------ 2 files changed, 46 insertions(+), 33 deletions(-) diff --git a/src/diff.jl b/src/diff.jl index a2ec1924a..ccf48ea85 100644 --- a/src/diff.jl +++ b/src/diff.jl @@ -150,8 +150,25 @@ function recursive_hasoperator(op, O) end end -function executediff(D, arg, simplify=false; robust=false, occurrences=nothing) - if robust || occurrences == nothing +""" + executediff(D, arg, simplify=false; occurrences=nothing) + +Apply the passed Differential D on the passed argument. + +This function differs to `expand_derivatives` in that in only expands the +passed differential and not any other Differentials it encounters. + +# Arguments +- `D::Differential`: The differential to apply +- `arg::Symbolic`: The symbolic expression to apply the differential on. +- `simplify::Bool=false`: Whether to simplify the resulting expression using + [`SymbolicUtils.simplify`](@ref). +- `occurrences=nothing`: Information about the occurrences of the independent + variable in the argument of the derivative. This is used internally for + optimization purposes. +""" +function executediff(D, arg, simplify=false; occurrences=nothing) + if occurrences == nothing occurrences = occursin_info(D.x, arg) end @@ -166,15 +183,15 @@ function executediff(D, arg, simplify=false; robust=false, occurrences=nothing) return D(arg) # base case if any argument is directly equal to the i.v. else return sum(inner_args, init=0) do a - return executediff(Differential(a), arg; robust) * - executediff(D, a; robust) + return executediff(Differential(a), arg) * + executediff(D, a) end end elseif op === (IfElse.ifelse) args = arguments(arg) O = op(args[1], - executediff(D, args[2], simplify; robust, occurrences=arguments(occurrences)[2]), - executediff(D, args[3], simplify; robust, occurrences=arguments(occurrences)[3])) + executediff(D, args[2], simplify; occurrences=arguments(occurrences)[2]), + executediff(D, args[3], simplify; occurrences=arguments(occurrences)[3])) return O elseif isa(op, Differential) # The recursive expand_derivatives was not able to remove @@ -184,12 +201,13 @@ function executediff(D, arg, simplify=false; robust=false, occurrences=nothing) if isequal(op.x, D.x) return D(arg) else - inner = executediff(D, arguments(arg)[1], false; robust) + inner = executediff(D, arguments(arg)[1], false) # if the inner expression is not expandable either, return if iscall(inner) && operation(inner) isa Differential return D(arg) else - return expand_derivatives(op(inner), simplify; robust) # TODO + # otherwise give the nested Differential another try + return executediff(op, inner, simplify) end end elseif isa(op, Integral) @@ -197,7 +215,7 @@ function executediff(D, arg, simplify=false; robust=false, occurrences=nothing) domain = op.domain.domain a, b = DomainSets.endpoints(domain) c = 0 - inner_function = expand_derivatives(arguments(arg)[1]; robust) # TODO + inner_function = arguments(arg)[1] if iscall(value(a)) t1 = SymbolicUtils.substitute(inner_function, Dict(op.domain.variables => value(a))) t2 = D(a) @@ -208,7 +226,7 @@ function executediff(D, arg, simplify=false; robust=false, occurrences=nothing) t2 = D(b) c += t1*t2 end - inner = executediff(D, arguments(arg)[1]; robust) + inner = executediff(D, arguments(arg)[1]) c += op(inner) return value(c) end @@ -220,7 +238,7 @@ function executediff(D, arg, simplify=false; robust=false, occurrences=nothing) c = 0 for i in 1:l - t2 = executediff(D, inner_args[i],false; robust, occurrences=arguments(occurrences)[i]) + t2 = executediff(D, inner_args[i],false; occurrences=arguments(occurrences)[i]) x = if _iszero(t2) t2 @@ -265,9 +283,6 @@ and other derivative rules to expand any derivatives it encounters. - `O::Symbolic`: The symbolic expression to expand. - `simplify::Bool=false`: Whether to simplify the resulting expression using [`SymbolicUtils.simplify`](@ref). -- `occurrences=nothing`: Information about the occurrences of the independent - variable in the argument of the derivative. This is used internally for - optimization purposes. # Examples ```jldoctest @@ -283,30 +298,29 @@ julia> dfx=expand_derivatives(Dx(f)) (k*((2abs(x - y)) / y - 2z)*IfElse.ifelse(signbit(x - y), -1, 1)) / y ``` """ -function expand_derivatives(O::Symbolic, simplify=false; robust=false, occurrences=nothing) +function expand_derivatives(O::Symbolic, simplify=false) if iscall(O) && isa(operation(O), Differential) arg = only(arguments(O)) - arg = expand_derivatives(arg, false; robust) + arg = expand_derivatives(arg, false) + return executediff(operation(O), arg, simplify) elseif iscall(O) && isa(operation(O), Integral) - return operation(O)(expand_derivatives(arguments(O)[1]; robust)) + return operation(O)(expand_derivatives(arguments(O)[1])) elseif !hasderiv(O) return O else - args = map(a->expand_derivatives(a, false; robust), arguments(O)) + args = map(a->expand_derivatives(a, false), arguments(O)) O1 = operation(O)(args...) return simplify ? SymbolicUtils.simplify(O1) : O1 end - - executediff(operation(O), arg, simplify; robust, occurrences) end -function expand_derivatives(n::Num, simplify=false; robust=false, occurrences=nothing) - wrap(expand_derivatives(value(n), simplify; robust, occurrences)) +function expand_derivatives(n::Num, simplify=false) + wrap(expand_derivatives(value(n), simplify)) end -function expand_derivatives(n::Complex{Num}, simplify=false; robust=false, occurrences=nothing) - wrap(ComplexTerm{Real}(expand_derivatives(real(n), simplify; robust, occurrences), - expand_derivatives(imag(n), simplify; robust, occurrences))) +function expand_derivatives(n::Complex{Num}, simplify=false) + wrap(ComplexTerm{Real}(expand_derivatives(real(n), simplify), + expand_derivatives(imag(n), simplify))) end -expand_derivatives(x, simplify=false; robust=false, occurrences=nothing) = x +expand_derivatives(x, simplify=false) = x _iszero(x) = false _isone(x) = false diff --git a/test/diff.jl b/test/diff.jl index 4a8484c1c..4bfff5f1b 100644 --- a/test/diff.jl +++ b/test/diff.jl @@ -356,9 +356,8 @@ let D = Differential(t) expr = b - ((D(b))^2) * D(D(b)) expr2 = D(expr) - @test isequal(expand_derivatives(expr), expand_derivatives(expr; robust=true)) - @test isequal(expand_derivatives(expr2; robust=true), D(b) - (D(b)^2)*D(D(D(b))) - 2D(b)*(D(D(b))^2)) - @test isequal(expand_derivatives(expr2; robust=true), expand_derivatives(expr2)) + @test isequal(expand_derivatives(expr), expr) + @test isequal(expand_derivatives(expr2), D(b) - (D(b)^2)*D(D(D(b))) - 2D(b)*(D(D(b))^2)) end # 1126 @@ -370,13 +369,13 @@ let expr_gen = (fun) -> D(D(((-D(D(fun))) / g(y)))) expr = expr_gen(g(y)) - @test isequal(expand_derivatives(expr), expand_derivatives(expr; robust=true)) + # just make sure that no errors are thrown in the following, the results are to complicated to compare + expand_derivatives(expr) expr = expr_gen(h(y)) - @test isequal(expand_derivatives(expr), expand_derivatives(expr; robust=true)) + expand_derivatives(expr) - expected = substitute(expand_derivatives(expr; robust=true), h(y) => f(y)) expr = expr_gen(f(y)) - @test isequal(expand(expand_derivatives(expr)), expand(expand_derivatives(expr; robust=true))) + expand_derivatives(expr) end # Check `is_derivative` function