Skip to content

Commit

Permalink
Fix ForwardDiff derivative of NaNMath.pow
Browse files Browse the repository at this point in the history
  • Loading branch information
devmotion committed Jan 9, 2025
1 parent eb3b5f6 commit 185c012
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 3 deletions.
6 changes: 3 additions & 3 deletions ext/SymbolicsForwardDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ end
# exponentiation #
#----------------#

for f in (:(Base.:^), :(NaNMath.pow))
for (f, log) in ((:(Base.:^), :(Base.log)), (:(NaNMath.pow), :(NaNMath.log)))
@eval begin
@define_binary_dual_op(
$f,
Expand All @@ -212,7 +212,7 @@ for f in (:(Base.:^), :(NaNMath.pow))
elseif iszero(vx) && vy > 0
logval = zero(vx)
else
logval = expv * log(vx)
logval = expv * ($log)(vx)
end
new_partials = _mul_partials(partials(x), partials(y), powval, logval)
return Dual{Txy}(expv, new_partials)
Expand All @@ -230,7 +230,7 @@ for f in (:(Base.:^), :(NaNMath.pow))
begin
v = value(y)
expv = ($f)(x, v)
deriv = (iszero(x) && v > 0) ? zero(expv) : expv*log(x)
deriv = (iszero(x) && v > 0) ? zero(expv) : expv*($log)(x)
return Dual{Ty}(expv, deriv * partials(y))
end,
$AMBIGUOUS_TYPES
Expand Down
6 changes: 6 additions & 0 deletions test/forwarddiff_symbolic_dual_ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -114,3 +114,9 @@ end
y(x) = isequal(z, x) ? 0 : x
@test ForwardDiff.derivative(y, 0) == 1 # expect ∂(x)/∂x
end

@testset "NaNMath.pow (issue #1399)" begin
@variables x
@test_throws DomainError substitute(ForwardDiff.derivative(z -> x^z, 0.5), x => -1.0)
@test isnan(Symbolics.value(substitute(ForwardDiff.derivative(z -> NaNMath.pow(x, z), 0.5), x => -1.0)))
end

0 comments on commit 185c012

Please sign in to comment.