Skip to content

Commit

Permalink
Add test_throws check for wrong user gradient assertion
Browse files Browse the repository at this point in the history
  • Loading branch information
KnutAM committed Mar 25, 2022
1 parent 42ab791 commit f151302
Showing 1 changed file with 9 additions and 1 deletion.
10 changes: 9 additions & 1 deletion test/test_ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -325,8 +325,16 @@ S(C) = S(C, μ, Kb)
@test gradient(ts_ts_ts, xs) gradient(ts_ts_ts_ana, xs)

end


# Test that AssertionError is thrown for erroneous user functions
test_self(x) = x
test_self_wronggradient(x) = test_self(x), one(x) # Only true for scalars, not for tensors
@implement_gradient test_self test_self_wronggradient
@test gradient(test_self, rand()) 1.0 # Should work fine
@test_throws AssertionError gradient(test_self, rand(Tensor{2,3}))
@test_throws AssertionError gradient(test_self, rand(SymmetricTensor{2,3}))
end


end # testsection

0 comments on commit f151302

Please sign in to comment.