Skip to content

Commit

Permalink
stable gradient tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jgreener64 committed Dec 20, 2024
1 parent 7283638 commit e1f6a1a
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions test/gradients.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ end

function loss(σ, r0, coords, velocities, boundary, pairwise_inters, general_inters,
neighbor_finder, simulator, n_steps, n_threads, n_atoms, atom_mass, bond_dists,
bond_is, bond_js, angles, torsions, ::Val{T}, ::Val{AT}) where {T, AT}
bond_is, bond_js, angles, torsions, rng, ::Val{T}, ::Val{AT}) where {T, AT}
atoms = [Atom(i, 1, atom_mass, (i % 2 == 0 ? T(-0.02) : T(0.02)), σ, T(0.2)) for i in 1:n_atoms]
bonds_inner = HarmonicBond{T, T}[]
for i in 1:(n_atoms ÷ 2)
Expand All @@ -98,7 +98,7 @@ end
energy_units=NoUnits,
)

simulate!(sys, simulator, n_steps; n_threads=n_threads)
simulate!(sys, simulator, n_steps; n_threads=n_threads, rng=rng)

return mean_min_separation(sys.coords, boundary, Val(T))
end
Expand Down Expand Up @@ -186,7 +186,7 @@ end
Const(general_inters), Const(neighbor_finder), Const(simulator),
Const(n_steps), Const(n_threads), Const(n_atoms), Const(atom_mass),
Const(bond_dists), Const(bond_is), Const(bond_js), Const(angles),
Const(torsions), Const(Val(T)), Const(Val(AT)),
Const(torsions), Const(rng), Const(Val(T)), Const(Val(AT)),
]
if forward
grad_enzyme = (
Expand Down Expand Up @@ -214,15 +214,15 @@ end
σ -> loss(
σ, r0, copy(coords), copy(velocities), boundary, pairwise_inters, general_inters,
neighbor_finder, simulator, n_steps, n_threads, n_atoms, atom_mass, bond_dists,
bond_is, bond_js, angles, torsions, Val(T), Val(AT),
bond_is, bond_js, angles, torsions, rng, Val(T), Val(AT),
),
σ,
),
central_fdm(6, 1)(
r0 -> loss(
σ, r0, copy(coords), copy(velocities), boundary, pairwise_inters, general_inters,
neighbor_finder, simulator, n_steps, n_threads, n_atoms, atom_mass, bond_dists,
bond_is, bond_js, angles, torsions, Val(T), Val(AT),
bond_is, bond_js, angles, torsions, rng, Val(T), Val(AT),
),
r0,
),
Expand Down

0 comments on commit e1f6a1a

Please sign in to comment.