diff --git a/docs/src/man/automatic_differentiation.md b/docs/src/man/automatic_differentiation.md
index 3fe67c1f..41a07b8c 100644
--- a/docs/src/man/automatic_differentiation.md
+++ b/docs/src/man/automatic_differentiation.md
@@ -116,7 +116,8 @@ The method relies on multiple dispatch to run your gradient function instead of
 the calling the regular function with dual numbers. Julia will always prefer
 the most specific type definition, but it can sometimes be hard to know which is
 most specific. Therefore, it is always recommended to test that your gradient function
-is called when testing, by e.g. inserting a print statement at the beginning.
+is called when testing, by e.g. inserting a print statement at the beginning as
+in the example below.
 
 ### Example
 Lets consider the function ``h(\mathbf{f}(\mathbf{g}(\mathbf{x})))``
diff --git a/src/automatic_differentiation.jl b/src/automatic_differentiation.jl
index e1c66ca1..7a7cf583 100644
--- a/src/automatic_differentiation.jl
+++ b/src/automatic_differentiation.jl
@@ -269,6 +269,7 @@ type `Tensors.Dual` (which is equivalent to `ForwardDiff.Dual`)
 """
 function propagate_gradient(f_dfdx::Function, x::Union{AbstractTensor{<:Any, <:Any, <:Dual}, Dual})
     fval, dfdx_val = f_dfdx(_extract_value(x))
+    _check_gradient_shape(fval,x,dfdx_val)
     return _insert_gradient(fval, dfdx_val, x)
 end
 
@@ -295,21 +296,18 @@
 """
 function _insert_gradient(f::Union{Number,AbstractTensor}, dfdg::Union{Number,AbstractTensor}, g::Dual{Tg}) where{Tg}
-    _check_gradient_shape(f,g,dfdg)
     dgdx = _extract_gradient(g, _get_original_gradient_input(g))
     dfdx = dfdg ⊗ dgdx
     return _insert_full_gradient(f, dfdx, Tg())
 end
 
 function _insert_gradient(f::Union{Number,AbstractTensor}, dfdg::Union{Number,AbstractTensor}, g::Vec{<:Any, <:Dual{Tg}}) where{Tg}
-    _check_gradient_shape(f,g,dfdg)
     dgdx = _extract_gradient(g, _get_original_gradient_input(g))
     dfdx = dfdg ⋅ dgdx
     return _insert_full_gradient(f, dfdx, Tg())
 end
 
 function _insert_gradient(f::Union{Number,AbstractTensor}, dfdg::Union{Number,AbstractTensor}, g::SecondOrderTensor{<:Any,<:Dual{Tg}}) where{Tg}
-    _check_gradient_shape(f,g,dfdg)
     dgdx = _extract_gradient(g, _get_original_gradient_input(g))
     dfdx = dfdg ⊡ dgdx
     return _insert_full_gradient(f, dfdx, Tg())
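
The documentation hunk above points readers to an example of verifying dispatch with a print statement. As a minimal sketch of that pattern, assuming the `Tensors.propagate_gradient` API shown in this diff (the function `f`, the toy derivative `2x`, and the printed message are illustrative, not part of the package):

```julia
using Tensors

# Plain method: a scalar-valued function of a symmetric second-order tensor.
f(x::SymmetricTensor{2}) = x ⊡ x

# More specific method for dual-valued input: instead of differentiating f
# itself, hand propagate_gradient the value together with the analytical
# derivative d(x ⊡ x)/dx = 2x.
function f(x::SymmetricTensor{2, dim, <:Tensors.Dual}) where {dim}
    println("analytical gradient called")  # confirms this method is dispatched to
    return Tensors.propagate_gradient(y -> (y ⊡ y, 2y), x)
end

x = rand(SymmetricTensor{2, 3})
gradient(f, x)  # should print "analytical gradient called" and return 2x
```

With the check moved into `propagate_gradient`, a mismatch between the shapes of the returned value and the supplied derivative (e.g. returning `(y ⊡ y, y ⊡ y)` above, a scalar derivative where a second-order tensor is expected) is caught at a single call site, before `_insert_gradient` dispatches on the input type, rather than being repeated in each `_insert_gradient` method.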