From fcdd072632abd286431c6f3cf8ec5aaa418c8ed5 Mon Sep 17 00:00:00 2001 From: James Schloss Date: Wed, 18 Dec 2024 11:29:34 +0100 Subject: [PATCH] messing around with extension loading and a few typos --- Project.toml | 4 +++- ext/MollyCUDAEnzymeExt.jl | 13 +++++++++++++ ext/MollyCUDAExt.jl | 1 - ext/MollyEnzymeExt.jl | 3 --- src/force.jl | 4 ++-- src/setup.jl | 4 ++-- test/Project.toml | 1 + test/runtests.jl | 12 +----------- 8 files changed, 22 insertions(+), 20 deletions(-) create mode 100644 ext/MollyCUDAEnzymeExt.jl diff --git a/Project.toml b/Project.toml index 262fa0b2..b13a118e 100644 --- a/Project.toml +++ b/Project.toml @@ -41,8 +41,9 @@ KernelDensity = "5ab0869b-81aa-558d-bb23-cbf5423bbe9b" PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d" [extensions] -MollyEnzymeExt = "Enzyme" MollyCUDAExt = "CUDA" +MollyEnzymeExt = "Enzyme" +MollyCUDAEnzymeExt = ["CUDA", "Enzyme"] MollyGLMakieExt = ["GLMakie", "Colors"] MollyKernelDensityExt = "KernelDensity" MollyPythonCallExt = "PythonCall" @@ -55,6 +56,7 @@ BioStructures = "4" CUDA = "4.2, 5" CellListMap = "0.8.11, 0.9" Chemfiles = "0.10.3" +ChainRulesCore = "1.25.0" Colors = "0.11, 0.12, 0.13" Combinatorics = "1" DataStructures = "0.18" diff --git a/ext/MollyCUDAEnzymeExt.jl b/ext/MollyCUDAEnzymeExt.jl new file mode 100644 index 00000000..c88ebd14 --- /dev/null +++ b/ext/MollyCUDAEnzymeExt.jl @@ -0,0 +1,13 @@ +module MollyCUDAEnzymeExt + +using Molly +using CUDA +using Enzyme + +ext = Base.get_extension(Molly,:MollyCUDAExt) + +EnzymeRules.inactive(::typeof(ext.cuda_threads_blocks_pairwise), args...) = nothing +EnzymeRules.inactive(::typeof(ext.cuda_threads_blocks_specific), args...) = nothing + + +end diff --git a/ext/MollyCUDAExt.jl b/ext/MollyCUDAExt.jl index 90ad4897..c0ed7b73 100644 --- a/ext/MollyCUDAExt.jl +++ b/ext/MollyCUDAExt.jl @@ -2,7 +2,6 @@ module MollyCUDAExt using Molly using CUDA -using ChainRulesCore using Atomix CUDA.Const(nl::Molly.NoNeighborList) = nl diff --git a/ext/MollyEnzymeExt.jl b/ext/MollyEnzymeExt.jl index 90e01539..26fd0e88 100644 --- a/ext/MollyEnzymeExt.jl +++ b/ext/MollyEnzymeExt.jl @@ -11,13 +11,10 @@ EnzymeRules.inactive(::typeof(Molly.n_infinite_dims), args...) = nothing EnzymeRules.inactive(::typeof(random_velocity), args...) = nothing EnzymeRules.inactive(::typeof(random_velocities), args...) = nothing EnzymeRules.inactive(::typeof(random_velocities!), args...) = nothing -EnzymeRules.inactive(::typeof(Molly.cuda_threads_blocks_pairwise), args...) = nothing -EnzymeRules.inactive(::typeof(Molly.cuda_threads_blocks_specific), args...) = nothing EnzymeRules.inactive(::typeof(Molly.check_force_units), args...) = nothing EnzymeRules.inactive(::typeof(Molly.check_energy_units), args...) = nothing EnzymeRules.inactive(::typeof(Molly.atoms_bonded_to_N), args...) = nothing EnzymeRules.inactive(::typeof(Molly.lookup_table), args...) = nothing -EnzymeRules.inactive(::typeof(Molly.cuda_threads_blocks_gbsa), args...) = nothing EnzymeRules.inactive(::typeof(find_neighbors), args...) = nothing EnzymeRules.inactive_type(::Type{DistanceNeighborFinder}) = nothing EnzymeRules.inactive(::typeof(visualize), args...) = nothing diff --git a/src/force.jl b/src/force.jl index 52bad345..c22a20ff 100644 --- a/src/force.jl +++ b/src/force.jl @@ -145,8 +145,8 @@ function forces(sys, neighbors, step_n::Integer=0; n_threads::Integer=Threads.nt return forces_nounits .* sys.force_units end -function forces_nounits!(fs_nounits, sys::System{D, false}, neighbors, fs_chunks=nothing, - step_n::Integer=0; n_threads::Integer=Threads.nthreads()) where D +function forces_nounits!(fs_nounits, sys::System{D, AT}, neighbors, fs_chunks=nothing, + step_n::Integer=0; n_threads::Integer=Threads.nthreads()) where {D, AT <: AbstractArray} pairwise_inters_nonl = filter(!use_neighbors, values(sys.pairwise_inters)) pairwise_inters_nl = filter( use_neighbors, values(sys.pairwise_inters)) sils_1_atoms = filter(il -> il isa InteractionList1Atoms, values(sys.specific_inter_lists)) diff --git a/src/setup.jl b/src/setup.jl index 7a519b5b..8a81a58d 100644 --- a/src/setup.jl +++ b/src/setup.jl @@ -905,8 +905,8 @@ function System(coord_file::AbstractString, ) end - atoms = ArrayType(atoms) - coords = ArrayType(coords) + atoms = ArrayType([atoms_abst...]) + coords_dev = ArrayType(coords) if isnothing(velocities) if units diff --git a/test/Project.toml b/test/Project.toml index e7e9a459..69fec660 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -10,6 +10,7 @@ DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" GLMakie = "e9467ef8-e4e7-5192-8a1a-b1aee30e663a" +GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" KernelDensity = "5ab0869b-81aa-558d-bb23-cbf5423bbe9b" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" diff --git a/test/runtests.jl b/test/runtests.jl index 3be6418e..73a6fbae 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -54,15 +54,6 @@ end # Allow CUDA device to be specified const DEVICE = parse(Int, get(ENV, "DEVICE", "0")) -<<<<<<< HEAD -const run_gpu_tests = get(ENV, "GPUTESTS", "1") != "0" && CUDA.functional() -const gpu_list = (run_gpu_tests ? (false, true) : (false,)) -if run_gpu_tests - device!(DEVICE) - @info "The GPU tests will be run on device $DEVICE" -elseif get(ENV, "GPUTESTS", "1") == "0" - @warn "The GPU tests will not be run as GPUTESTS is set to 0" -======= const run_cuda_tests = get(ENV, "GPUTESTS", "1") != "0" && CUDA.functional() const run_rocm_tests = get(ENV, "GPUTESTS", "1") != "0" && AMDGPU.functional() @@ -72,14 +63,13 @@ if run_cuda_tests array_list = (array_list..., CuArray) device!(parse(Int, DEVICE)) @info "The CUDA tests will be run on device $DEVICE" ->>>>>>> c820f41f (Adding KernelAbstractions tooling for Molly and tests) else @warn "The CUDA tests will not be run as a CUDA-enabled device is not available" end if run_rocm_tests array_list = (array_list..., ROCArray) - AMDGPU.device!(AMDGPU.devices()[parse(Int, DEVICE)+1]) + AMDGPU.device!(AMDGPU.device(DEVICE+1)) @info "The ROCM tests will be run on device $DEVICE" else @warn "The ROCM tests will not be run as a ROCM-enabled device is not available"