Skip to content

Commit

Permalink
Final cosmetics, fixes Ferrite-FEM#179
Browse files Browse the repository at this point in the history
  • Loading branch information
KnutAM committed Mar 26, 2022
1 parent f151302 commit c46ae78
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 4 deletions.
3 changes: 2 additions & 1 deletion docs/src/man/automatic_differentiation.md
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,8 @@ The method relies on multiple dispatch to run your gradient function instead of
calling the regular function with dual numbers. Julia will always prefer the
most specific type definition, but it can sometimes be hard to know which is most
specific. Therefore, it is always recommended to test that your gradient function
is called when testing, by e.g. inserting a print statement at the beginning as
in the example below.

### Example
Let's consider the function ``h(\mathbf{f}(\mathbf{g}(\mathbf{x})))``
Expand Down
4 changes: 1 addition & 3 deletions src/automatic_differentiation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,7 @@ type `Tensors.Dual` (which is equivalent to `ForwardDiff.Dual`)
"""
function propagate_gradient(f_dfdx::Function, x::Union{AbstractTensor{<:Any, <:Any, <:Dual}, Dual})
    # Evaluate the user-supplied analytical function/gradient pair at the
    # primal (non-dual) value of `x`, then re-insert the gradient information
    # into the dual-number bookkeeping so outer differentiation still works.
    primal_x = _extract_value(x)
    fval, dfdx_val = f_dfdx(primal_x)
    # Fail early with a descriptive error if the returned gradient has a
    # shape inconsistent with `f` and `x`.
    _check_gradient_shape(fval, x, dfdx_val)
    return _insert_gradient(fval, dfdx_val, x)
end

Expand All @@ -295,21 +296,18 @@ end
"""
function _insert_gradient(f::Union{Number,AbstractTensor}, dfdg::Union{Number,AbstractTensor}, g::Dual{Tg}) where{Tg}
    _check_gradient_shape(f,g,dfdg)
    # dg/dx: gradient of the scalar intermediate `g` w.r.t. the original input.
    dgdx = _extract_gradient(g, _get_original_gradient_input(g))
    # Chain rule with a scalar `g`: no indices of `g` to contract over, so
    # df/dx = df/dg ⊗ dg/dx (outer product; plain multiplication for scalars).
    # NOTE(review): the `⊗` operator was garbled in the extracted text and has
    # been restored here.
    dfdx = dfdg ⊗ dgdx
    return _insert_full_gradient(f, dfdx, Tg())
end

function _insert_gradient(f::Union{Number,AbstractTensor}, dfdg::Union{Number,AbstractTensor}, g::Vec{<:Any, <:Dual{Tg}}) where{Tg}
    _check_gradient_shape(f,g,dfdg)
    # dg/dx: gradient of the first-order intermediate `g` w.r.t. the original input.
    dgdx = _extract_gradient(g, _get_original_gradient_input(g))
    # Chain rule with a `Vec` intermediate: contract the single index of `g`,
    # i.e. df/dx = df/dg ⋅ dg/dx (single contraction).
    # NOTE(review): the `⋅` operator was garbled in the extracted text and has
    # been restored here.
    dfdx = dfdg ⋅ dgdx
    return _insert_full_gradient(f, dfdx, Tg())
end

function _insert_gradient(f::Union{Number,AbstractTensor}, dfdg::Union{Number,AbstractTensor}, g::SecondOrderTensor{<:Any,<:Dual{Tg}}) where{Tg}
_check_gradient_shape(f,g,dfdg)
dgdx = _extract_gradient(g, _get_original_gradient_input(g))
dfdx = dfdg ⊡ dgdx
return _insert_full_gradient(f, dfdx, Tg())
Expand Down

0 comments on commit c46ae78

Please sign in to comment.