Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adding preliminary AMDGPU support #99

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@ on:
push:
branches:
- master
types:
- opened
- reopened
- synchronize
- ready_for_review
tags: '*'
schedule:
- cron: '00 04 * * 1' # 4am every Monday
Expand All @@ -14,6 +19,7 @@ jobs:
test:
name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }}
runs-on: ${{ matrix.os }}
if: ${{ github.event_name == 'push' || !github.event.pull_request.draft }}
strategy:
fail-fast: false
matrix:
Expand Down Expand Up @@ -49,6 +55,7 @@ jobs:
docs:
name: Documentation
runs-on: ubuntu-latest
if: ${{ github.event_name == 'push' || !github.event.pull_request.draft }}
steps:
- uses: actions/checkout@v2
- uses: julia-actions/setup-julia@v1
Expand Down
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ authors = ["Joe G Greener <[email protected]>"]
version = "0.13.0"

[deps]
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
AtomsBase = "a963bdd2-2df7-4f54-a1ee-49d51e6be12a"
BioStructures = "de9282ab-8554-53be-b2d6-f6c222edabfc"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Expand Down Expand Up @@ -32,6 +33,7 @@ UnitfulChainRules = "f31437dd-25a7-4345-875f-756556e6935d"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
AMDGPU = "0.4"
AtomsBase = "0.2"
BioStructures = "1"
CUDA = "3"
Expand Down
9 changes: 9 additions & 0 deletions src/Molly.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,15 @@ import Chemfiles
using Colors
using Combinatorics
using CUDA
if has_cuda_gpu()
CUDA.allowscalar(false)
end

using AMDGPU
if has_rocm_gpu()
AMDGPU.allowscalar(false)
end

using DataStructures
using Distances
using Distributions
Expand Down
4 changes: 2 additions & 2 deletions src/chain_rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -109,15 +109,15 @@ function ChainRulesCore.rrule(::typeof(unsafe_getindex), arr, inds)
end

# Not faster on CPU
function ChainRulesCore.rrule(::typeof(getindices_i), arr::CuArray, neighbors)
function ChainRulesCore.rrule(::typeof(getindices_i), arr::AT, neighbors) where AT <: Union{CuArray, ROCArray}
Y = getindices_i(arr, neighbors)
@views @inbounds function getindices_i_pullback(Ȳ)
return NoTangent(), accumulate_bounds(Ȳ, neighbors.atom_bounds_i), nothing
end
return Y, getindices_i_pullback
end

function ChainRulesCore.rrule(::typeof(getindices_j), arr::CuArray, neighbors)
function ChainRulesCore.rrule(::typeof(getindices_j), arr::AT, neighbors) where AT <: Union{CuArray, ROCArray}
Y = getindices_j(arr, neighbors)
@views @inbounds function getindices_j_pullback(Ȳ)
return NoTangent(), accumulate_bounds(Ȳ[neighbors.sortperm_j], neighbors.atom_bounds_j), nothing
Expand Down
27 changes: 16 additions & 11 deletions src/gradients.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,10 @@ Allows gradients for individual parameters to be tracked.
Returns atoms, pairwise interactions, specific interaction lists and general
interactions.
"""
function inject_gradients(sys, params_dic, gpu::Bool=isa(sys.coords, CuArray))
function inject_gradients(sys, params_dic; AT = find_array_type(sys.coords),
gpu::Bool = (AT <: Union{CuArray, ROCArray}))
if gpu
atoms_grad = CuArray(inject_atom.(Array(sys.atoms), sys.atoms_data, (params_dic,)))
atoms_grad = AT(inject_atom.(Array(sys.atoms), sys.atoms_data, (params_dic,)))
else
atoms_grad = inject_atom.(sys.atoms, sys.atoms_data, (params_dic,))
end
Expand All @@ -100,7 +101,7 @@ function inject_gradients(sys, params_dic, gpu::Bool=isa(sys.coords, CuArray))
pis_grad = sys.pairwise_inters
end
if length(sys.specific_inter_lists) > 0
sis_grad = inject_interaction_list.(sys.specific_inter_lists, (params_dic,), gpu)
sis_grad = inject_interaction_list.(sys.specific_inter_lists, (params_dic,), gpu, AT)
else
sis_grad = sys.specific_inter_lists
end
Expand All @@ -127,36 +128,40 @@ function inject_atom(at, at_data, params_dic)
)
end

function inject_interaction_list(inter::InteractionList1Atoms, params_dic, gpu)
function inject_interaction_list(inter::InteractionList1Atoms, params_dic, gpu,
AT)
if gpu
inters_grad = CuArray(inject_interaction.(Array(inter.inters), inter.types, (params_dic,)))
inters_grad = AT(inject_interaction.(Array(inter.inters), inter.types, (params_dic,)))
else
inters_grad = inject_interaction.(inter.inters, inter.types, (params_dic,))
end
InteractionList1Atoms(inter.is, inter.types, inters_grad)
end

function inject_interaction_list(inter::InteractionList2Atoms, params_dic, gpu)
function inject_interaction_list(inter::InteractionList2Atoms, params_dic, gpu,
AT)
if gpu
inters_grad = CuArray(inject_interaction.(Array(inter.inters), inter.types, (params_dic,)))
inters_grad = AT(inject_interaction.(Array(inter.inters), inter.types, (params_dic,)))
else
inters_grad = inject_interaction.(inter.inters, inter.types, (params_dic,))
end
InteractionList2Atoms(inter.is, inter.js, inter.types, inters_grad)
end

function inject_interaction_list(inter::InteractionList3Atoms, params_dic, gpu)
function inject_interaction_list(inter::InteractionList3Atoms, params_dic, gpu,
AT)
if gpu
inters_grad = CuArray(inject_interaction.(Array(inter.inters), inter.types, (params_dic,)))
inters_grad = AT(inject_interaction.(Array(inter.inters), inter.types, (params_dic,)))
else
inters_grad = inject_interaction.(inter.inters, inter.types, (params_dic,))
end
InteractionList3Atoms(inter.is, inter.js, inter.ks, inter.types, inters_grad)
end

function inject_interaction_list(inter::InteractionList4Atoms, params_dic, gpu)
function inject_interaction_list(inter::InteractionList4Atoms, params_dic, gpu,
AT)
if gpu
inters_grad = CuArray(inject_interaction.(Array(inter.inters), inter.types, (params_dic,)))
inters_grad = AT(inject_interaction.(Array(inter.inters), inter.types, (params_dic,)))
else
inters_grad = inject_interaction.(inter.inters, inter.types, (params_dic,))
end
Expand Down
10 changes: 10 additions & 0 deletions src/interactions/implicit_solvent.jl
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,10 @@ function ImplicitSolventOBC(atoms::AbstractArray{Atom{T, M, D, E}},
or = CuArray(offset_radii)
sor = CuArray(scaled_offset_radii)
is, js = CuArray(inds_i), CuArray(inds_j)
elseif isa(atoms, ROCArray)
or = ROCArray(offset_radii)
sor = ROCArray(scaled_offset_radii)
is, js = ROCArray(inds_i), ROCArrayArray(inds_j)
else
or = offset_radii
sor = scaled_offset_radii
Expand Down Expand Up @@ -555,6 +559,12 @@ function ImplicitSolventGBN2(atoms::AbstractArray{Atom{T, M, D, E}},
is, js = CuArray(inds_i), CuArray(inds_j)
d0s, m0s = CuArray(table_d0), CuArray(table_m0)
αs, βs, γs = CuArray(αs_cpu), CuArray(βs_cpu), CuArray(γs_cpu)
elseif isa(atoms, ROCArray)
or = ROCArray(offset_radii)
sor = ROCArray(scaled_offset_radii)
is, js = ROCArray(inds_i), ROCArray(inds_j)
d0s, m0s = ROCArray(table_d0), ROCArray(table_m0)
αs, βs, γs = ROCArray(αs_cpu), ROCArray(βs_cpu), ROCArray(γs_cpu)
else
or = offset_radii
sor = scaled_offset_radii
Expand Down
4 changes: 4 additions & 0 deletions src/neighbors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,10 @@ function DistanceVecNeighborFinder(;
is = CuArray(hcat([collect(1:n_atoms) for i in 1:n_atoms]...))
js = CuArray(permutedims(is, (2, 1)))
m14 = CuArray(matrix_14)
elseif isa(nb_matrix, ROCArray)
is = ROCArray(hcat([collect(1:n_atoms) for i in 1:n_atoms]...))
js = ROCArray(permutedims(is, (2, 1)))
m14 = ROCArray(matrix_14)
else
is = hcat([collect(1:n_atoms) for i in 1:n_atoms]...)
js = permutedims(is, (2, 1))
Expand Down
57 changes: 38 additions & 19 deletions src/setup.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,24 @@ export
is_heavy_atom,
add_position_restraints

# Creating default Array Type (AT) for users who did not specify
function configure_array_type(gpu)
if !gpu
AT = Array
elseif has_rocm_gpu() && has_cuda_gpu()
@warn("Both AMD and NVIDIA gpus available!\n"*
"Defaulting to CuArray...\n"*
"If you would like to use your AMD GPU, please specify " *
"System(...; AT = ROCArray)")
AT = CuArray
elseif has_cuda_gpu()
AT = CuArray
elseif has_rocm_gpu()
AT = ROCArray
end
return AT
end

"""
place_atoms(n_atoms, boundary; min_dist=nothing, max_attempts=100)

Expand Down Expand Up @@ -372,7 +390,8 @@ function System(coord_file::AbstractString,
implicit_solvent=nothing,
center_coords::Bool=true,
rename_terminal_res::Bool=true,
kappa=0.0u"nm^-1")
kappa=0.0u"nm^-1",
AT = configure_array_type(gpu))
T = typeof(force_field.weight_14_coulomb)

# Chemfiles uses zero-based indexing, be careful
Expand Down Expand Up @@ -721,26 +740,25 @@ function System(coord_file::AbstractString,
specific_inter_array = []
if length(bonds.is) > 0
push!(specific_inter_array, InteractionList2Atoms(
bonds.is, bonds.js, bonds.types,
gpu ? CuArray([bonds.inters...]) : [bonds.inters...],
bonds.is, bonds.js, bonds.types, AT([bonds.inters...]),
))
end
if length(angles.is) > 0
push!(specific_inter_array, InteractionList3Atoms(
angles.is, angles.js, angles.ks, angles.types,
gpu ? CuArray([angles.inters...]) : [angles.inters...],
AT([angles.inters...]),
))
end
if length(torsions.is) > 0
push!(specific_inter_array, InteractionList4Atoms(
torsions.is, torsions.js, torsions.ks, torsions.ls, torsions.types,
gpu ? CuArray(torsion_inters_pad) : torsion_inters_pad,
AT(torsion_inters_pad),
))
end
if length(impropers.is) > 0
push!(specific_inter_array, InteractionList4Atoms(
impropers.is, impropers.js, impropers.ks, impropers.ls, impropers.types,
gpu ? CuArray(improper_inters_pad) : improper_inters_pad,
AT(improper_inters_pad),
))
end
specific_inter_lists = tuple(specific_inter_array...)
Expand Down Expand Up @@ -771,8 +789,8 @@ function System(coord_file::AbstractString,
atoms = [atoms...]
if gpu_diff_safe
neighbor_finder = DistanceVecNeighborFinder(
nb_matrix=gpu ? CuArray(nb_matrix) : nb_matrix,
matrix_14=gpu ? CuArray(matrix_14) : matrix_14,
nb_matrix=AT(nb_matrix),
matrix_14=AT(matrix_14),
n_steps=10,
dist_cutoff=T(dist_neighbors),
)
Expand All @@ -787,8 +805,8 @@ function System(coord_file::AbstractString,
)
end
if gpu
atoms = CuArray(atoms)
coords = CuArray(coords)
atoms = AT(atoms)
coords = AT(coords)
end

if isnothing(velocities)
Expand Down Expand Up @@ -845,7 +863,9 @@ function System(T::Type,
gpu_diff_safe::Bool=gpu,
dist_cutoff=units ? 1.0u"nm" : 1.0,
dist_neighbors=units ? 1.2u"nm" : 1.2,
center_coords::Bool=true)
center_coords::Bool=true,
AT = configure_array_type(gpu))

# Read force field and topology file
atomtypes = Dict{String, Atom}()
bondtypes = Dict{String, HarmonicBond}()
Expand Down Expand Up @@ -1108,20 +1128,19 @@ function System(T::Type,
specific_inter_array = []
if length(bonds.is) > 0
push!(specific_inter_array, InteractionList2Atoms(
bonds.is, bonds.js, bonds.types,
gpu ? CuArray([bonds.inters...]) : [bonds.inters...],
bonds.is, bonds.js, bonds.types, AT([bonds.inters...]),
))
end
if length(angles.is) > 0
push!(specific_inter_array, InteractionList3Atoms(
angles.is, angles.js, angles.ks, angles.types,
gpu ? CuArray([angles.inters...]) : [angles.inters...],
AT([angles.inters...]),
))
end
if length(torsions.is) > 0
push!(specific_inter_array, InteractionList4Atoms(
torsions.is, torsions.js, torsions.ks, torsions.ls, torsions.types,
gpu ? CuArray([torsions.inters...]) : [torsions.inters...],
AT([torsions.inters...]),
))
end
specific_inter_lists = tuple(specific_inter_array...)
Expand All @@ -1130,8 +1149,8 @@ function System(T::Type,

if gpu_diff_safe
neighbor_finder = DistanceVecNeighborFinder(
nb_matrix=gpu ? CuArray(nb_matrix) : nb_matrix,
matrix_14=gpu ? CuArray(matrix_14) : matrix_14,
nb_matrix=AT(nb_matrix),
matrix_14=AT(matrix_14),
n_steps=10,
dist_cutoff=T(dist_neighbors),
)
Expand All @@ -1146,8 +1165,8 @@ function System(T::Type,
)
end
if gpu
atoms = CuArray(atoms)
coords = CuArray(coords)
atoms = AT(atoms)
coords = AT(coords)
end

if isnothing(velocities)
Expand Down
Loading