Skip to content

Commit

Permalink
StaticArrays gradient changes
Browse files Browse the repository at this point in the history
jgreener64 committed Dec 18, 2023

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
1 parent 8e2a527 commit e1c6325
Showing 2 changed files with 8 additions and 8 deletions.
8 changes: 0 additions & 8 deletions src/chain_rules.jl
Original file line number Diff line number Diff line change
@@ -21,14 +21,6 @@
@non_differentiable System(T::Type, coord_file::AbstractString, top_file::AbstractString)
@non_differentiable System(coord_file::AbstractString, top_file::AbstractString)

function ChainRulesCore.rrule(T::Type{<:SVector}, vs::Number...)
Y = T(vs...)
function SVector_pullback(Ȳ)
return NoTangent(), Ȳ...
end
return Y, SVector_pullback
end

function ChainRulesCore.rrule(T::Type{<:Atom}, vs...)
Y = T(vs...)
function Atom_pullback(Ȳ)
8 changes: 8 additions & 0 deletions src/zygote.jl
Original file line number Diff line number Diff line change
@@ -81,6 +81,14 @@ function Base.:+(y::NamedTuple{(:atoms, :coords, :boundary,
return r + y
end

function Zygote.accum(x::NamedTuple{(:side_lengths,), Tuple{SVector{3, T}}}, y::SVector{3, T}) where T
CubicBoundary(x.side_lengths .+ y; check_positive=false)
end

function Zygote.accum(x::NamedTuple{(:side_lengths,), Tuple{SVector{2, T}}}, y::SVector{2, T}) where T
RectangularBoundary(x.side_lengths .+ y; check_positive=false)
end

function Zygote.accum(x::NamedTuple{(:side_lengths,), Tuple{SizedVector{3, T, Vector{T}}}}, y::SVector{3, T}) where T
CubicBoundary(SVector{3, T}(x.side_lengths .+ y); check_positive=false)
end

0 comments on commit e1c6325

Please sign in to comment.