diff --git a/src/diff.jl b/src/diff.jl index d9d4c1bbf..891e2cde2 100644 --- a/src/diff.jl +++ b/src/diff.jl @@ -371,6 +371,9 @@ derivative(::typeof(+), args::NTuple{N,Any}, ::Val) where {N} = 1 derivative(::typeof(*), args::NTuple{N,Any}, ::Val{i}) where {N,i} = *(deleteat!(collect(args), i)...) derivative(::typeof(one), args::Tuple{<:Any}, ::Val) = 0 +derivative(f::Function, x::Num) = derivative(f(x), x) +derivative(::Function, x::Any) = TypeError(:derivative, "2nd argument", Num, typeof(x)) |> throw + function count_order(x) @assert !(x isa Symbol) "The variable $x must have an order of differentiation that is greater or equal to 1!" n = 1 diff --git a/test/diff.jl b/test/diff.jl index 88956f138..eb500cb9d 100644 --- a/test/diff.jl +++ b/test/diff.jl @@ -378,4 +378,23 @@ let Dt = Differential(t)^0 @test isequal(Dt, identity) test_equal(Dt(t + 2t^2), t + 2t^2) -end \ No newline at end of file +end + +# Check `Function` inputs for derivative (#1085) +let + @variables x + @testset for f in [sqrt, sin, acos, exp, cis] + @test isequal( + Symbolics.derivative(f, x), + Symbolics.derivative(f(x), x) + ) + end +end + +# Check `Function` inputs throw for non-Num second input (#1085) +let + @testset for f in [sqrt, sin, acos, exp, cis] + @test_throws TypeError Symbolics.derivative(f, rand()) + @test_throws TypeError Symbolics.derivative(f, Val(rand(Int))) + end +end