Skip to content

Commit

Permalink
I think the tests pass now?
Browse files Browse the repository at this point in the history
  • Loading branch information
leios committed Dec 19, 2024
1 parent cac4be1 commit cbf6f56
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 30 deletions.
12 changes: 6 additions & 6 deletions src/interactions/implicit_solvent.jl
Original file line number Diff line number Diff line change
Expand Up @@ -565,7 +565,7 @@ function ImplicitSolventGBN2(atoms::AbstractArray{Atom{TY, M, T, D, E}},
end

if isa(atoms, AbstractGPUArray)
ArrayType = fine_array_type(atoms)
ArrayType = get_array_type(atoms)
or = ArrayType(offset_radii)
sor = ArrayType(scaled_offset_radii)
is, js = ArrayType(inds_i), ArrayType(inds_j)
Expand Down Expand Up @@ -948,8 +948,8 @@ function forces_gbsa(sys, inter, Bs, B_grads, I_grads, born_forces, atom_charges
return fs
end

function forces_gbsa(sys::System{D, true, T}, inter, Bs, B_grads, I_grads, born_forces,
atom_charges) where {D, T}
function forces_gbsa(sys::System{D, AT, T}, inter, Bs, B_grads, I_grads, born_forces,
atom_charges) where {D, AT <: AbstractGPUArray, T}
fs_mat_1, born_forces_mod_ustrip = gbsa_force_1_gpu(sys.coords, sys.boundary, inter.dist_cutoff,
inter.factor_solute, inter.factor_solvent, inter.kappa, Bs, atom_charges,
sys.force_units)
Expand Down Expand Up @@ -1149,8 +1149,8 @@ function gb_energy_loop(coord_i, coord_j, i, j, charge_i, charge_j, Bi, Bj, ori,
end
end

function AtomsCalculators.potential_energy(sys::System{<:Any, false, T}, inter::AbstractGBSA;
kwargs...) where T
function AtomsCalculators.potential_energy(sys::System{<:Any, AT, T}, inter::AbstractGBSA;
kwargs...) where {AT, T}
coords, boundary = sys.coords, sys.boundary
Bs, B_grads, I_grads = born_radii_and_grad(inter, coords, boundary)
atom_charges = charge.(sys.atoms)
Expand All @@ -1169,7 +1169,7 @@ function AtomsCalculators.potential_energy(sys::System{<:Any, false, T}, inter::
return E
end

function AtomsCalculators.potential_energy(sys::System{<:Any, true}, inter::AbstractGBSA; kwargs...)
function AtomsCalculators.potential_energy(sys::System{<:Any, AT}, inter::AbstractGBSA; kwargs...) where AT <: AbstractGPUArray
coords, atoms, boundary = sys.coords, sys.atoms, sys.boundary
Bs, B_grads, I_grads = born_radii_and_grad(inter, coords, boundary)

Expand Down
57 changes: 33 additions & 24 deletions test/gradients.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,24 +36,31 @@ end

@testset "Differentiable simulation" begin
runs = [ # gpu par fwd f32 obc2 gbn2 tol_σ tol_r0
("CPU" , false, false, false, false, false, false, 1e-4, 1e-4),
("CPU forward" , false, false, true , false, false, false, 0.5 , 0.1 ),
("CPU f32" , false, false, false, true , false, false, 0.01, 5e-4),
("CPU obc2" , false, false, false, false, true , false, 1e-4, 1e-4),
("CPU gbn2" , false, false, false, false, false, true , 1e-4, 1e-4),
("CPU gbn2 forward", false, false, true , false, false, true , 0.5 , 0.1 ),
("CPU" , Array, false, false, false, false, false, 1e-4, 1e-4),
("CPU forward" , Array, false, true , false, false, false, 0.5 , 0.1 ),
("CPU f32" , Array, false, false, true , false, false, 0.01, 5e-4),
("CPU obc2" , Array, false, false, false, true , false, 1e-4, 1e-4),
("CPU gbn2" , Array, false, false, false, false, true , 1e-4, 1e-4),
("CPU gbn2 forward", Array, false, true , false, false, true , 0.5 , 0.1 ),
]
if run_parallel_tests # gpu par fwd f32 obc2 gbn2 tol_σ tol_r0
push!(runs, ("CPU parallel" , false, true , false, false, false, false, 1e-4, 1e-4))
push!(runs, ("CPU parallel forward", false, true , true , false, false, false, 0.5 , 0.1 ))
push!(runs, ("CPU parallel f32" , false, true , false, true , false, false, 0.01, 5e-4))
push!(runs, ("CPU parallel" , Array, true , false, false, false, false, 1e-4, 1e-4))
push!(runs, ("CPU parallel forward", Array, true , true , false, false, false, 0.5 , 0.1 ))
push!(runs, ("CPU parallel f32" , Array, true , false, true , false, false, 0.01, 5e-4))
end
if run_gpu_tests # gpu par fwd f32 obc2 gbn2 tol_σ tol_r0
push!(runs, ("GPU" , true , false, false, false, false, false, 0.25, 20.0))
push!(runs, ("GPU forward" , true , false, true , false, false, false, 0.25, 20.0))
push!(runs, ("GPU f32" , true , false, false, true , false, false, 0.5 , 50.0))
push!(runs, ("GPU obc2" , true , false, false, false, true , false, 0.25, 20.0))
push!(runs, ("GPU gbn2" , true , false, false, false, false, true , 0.25, 20.0))
if run_cuda_tests # gpu par fwd f32 obc2 gbn2 tol_σ tol_r0
push!(runs, ("CUDA" , CuArray, false, false, false, false, false, 0.25, 20.0))
push!(runs, ("CUDA forward" , CuArray, false, true , false, false, false, 0.25, 20.0))
push!(runs, ("CUDA f32" , CuArray, false, false, true , false, false, 0.5 , 50.0))
push!(runs, ("CUDA obc2" , CuArray, false, false, false, true , false, 0.25, 20.0))
push!(runs, ("CUDA gbn2" , CuArray, false, false, false, false, true , 0.25, 20.0))
end
if run_rocm_tests # gpu par fwd f32 obc2 gbn2 tol_σ tol_r0
push!(runs, ("ROCM" , ROCArray, false, false, false, false, false, 0.25, 20.0))
push!(runs, ("ROCM forward" , ROCArray, false, true , false, false, false, 0.25, 20.0))
push!(runs, ("ROCM f32" , ROCArray, false, false, true , false, false, 0.5 , 50.0))
push!(runs, ("ROCM obc2" , ROCArray, false, false, false, true , false, 0.25, 20.0))
push!(runs, ("ROCM gbn2" , ROCArray, false, false, false, false, true , 0.25, 20.0))
end

function mean_min_separation(coords, boundary, ::Val{T}) where T
Expand Down Expand Up @@ -103,9 +110,8 @@ end
return mean_min_separation(sys.coords, boundary, Val(T))
end

for (name, gpu, parallel, forward, f32, obc2, gbn2, tol_σ, tol_r0) in runs
for (name, AT, parallel, forward, f32, obc2, gbn2, tol_σ, tol_r0) in runs
T = f32 ? Float32 : Float64
AT = gpu ? CuArray : Array
σ = T(0.4)
r0 = T(1.0)
n_atoms = 50
Expand Down Expand Up @@ -245,13 +251,13 @@ end
end

@testset "Differentiable protein" begin
function create_sys(gpu::Bool)
function create_sys(ArrayType)
ff = MolecularForceField(joinpath.(ff_dir, ["ff99SBildn.xml", "his.xml"])...; units=false)
return System(
joinpath(data_dir, "6mrr_nowater.pdb"),
ff;
units=false,
gpu=gpu,
ArrayType=ArrayType,
implicit_solvent="gbn2",
kappa=0.7,
)
Expand Down Expand Up @@ -402,10 +408,13 @@ end

platform_runs = [("CPU", false, false)]
if run_parallel_tests
push!(platform_runs, ("CPU parallel", false, true))
push!(platform_runs, ("CPU parallel", Array, true))
end
if run_cuda_tests
push!(platform_runs, ("CUDA", CuArray, false))
end
if run_gpu_tests
push!(platform_runs, ("GPU", true, false))
if run_rocm_tests
push!(platform_runs, ("ROCM", ROCArray, false))
end
test_runs = [
("Energy", test_energy_grad, 1e-8),
Expand All @@ -423,8 +432,8 @@ end
)

for (test_name, test_fn, test_tol) in test_runs
for (platform, gpu, parallel) in platform_runs
sys_ref = create_sys(gpu)
for (platform, AT, parallel) in platform_runs
sys_ref = create_sys(AT)
n_threads = parallel ? Threads.nthreads() : 1
grads_enzyme = Dict(k => 0.0 for k in keys(params_dic))
autodiff(
Expand Down

0 comments on commit cbf6f56

Please sign in to comment.