Skip to content

Commit

Permalink
ArrayType -> array_type
Browse files Browse the repository at this point in the history
  • Loading branch information
leios committed Dec 23, 2024
1 parent 84bb0c8 commit 5c1b912
Show file tree
Hide file tree
Showing 13 changed files with 118 additions and 118 deletions.
16 changes: 8 additions & 8 deletions benchmark/benchmarks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ const starting_coords_f32 = [Float32.(c) for c in starting_coords]
const starting_velocities_f32 = [Float32.(c) for c in starting_velocities]

function test_sim(nl::Bool, parallel::Bool, f32::Bool,
ArrayType::Type{AT}) where AT <: AbstractArray
array_type::Type{AT}) where AT <: AbstractArray
n_atoms = 400
n_steps = 200
atom_mass = f32 ? 10.0f0u"g/mol" : 10.0u"g/mol"
Expand All @@ -73,26 +73,26 @@ function test_sim(nl::Bool, parallel::Bool, f32::Bool,
r0 = f32 ? 0.2f0u"nm" : 0.2u"nm"
bonds = [HarmonicBond(k=k, r0=r0) for i in 1:(n_atoms ÷ 2)]
specific_inter_lists = (InteractionList2Atoms(
ArrayType(Int32.(collect(1:2:n_atoms))),
ArrayType(Int32.(collect(2:2:n_atoms))),
ArrayType(bonds),
array_type(Int32.(collect(1:2:n_atoms))),
array_type(Int32.(collect(2:2:n_atoms))),
array_type(bonds),
),)

neighbor_finder = NoNeighborFinder()
cutoff = DistanceCutoff(f32 ? 1.0f0u"nm" : 1.0u"nm")
pairwise_inters = (LennardJones(use_neighbors=false, cutoff=cutoff),)
if nl
neighbor_finder = DistanceNeighborFinder(
eligible=ArrayType(trues(n_atoms, n_atoms)),
eligible=array_type(trues(n_atoms, n_atoms)),
n_steps=10,
dist_cutoff=f32 ? 1.5f0u"nm" : 1.5u"nm",
)
pairwise_inters = (LennardJones(use_neighbors=true, cutoff=cutoff),)
end

coords = ArrayType(deepcopy(f32 ? starting_coords_f32 : starting_coords))
velocities = ArrayType(deepcopy(f32 ? starting_velocities_f32 : starting_velocities))
atoms = ArrayType([Atom(charge=f32 ? 0.0f0 : 0.0, mass=atom_mass, σ=f32 ? 0.2f0u"nm" : 0.2u"nm",
coords = array_type(deepcopy(f32 ? starting_coords_f32 : starting_coords))
velocities = array_type(deepcopy(f32 ? starting_velocities_f32 : starting_velocities))
atoms = array_type([Atom(charge=f32 ? 0.0f0 : 0.0, mass=atom_mass, σ=f32 ? 0.2f0u"nm" : 0.2u"nm",
ϵ=f32 ? 0.2f0u"kJ * mol^-1" : 0.2u"kJ * mol^-1") for i in 1:n_atoms])

sys = System(
Expand Down
4 changes: 2 additions & 2 deletions benchmark/protein.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ const data_dir = normpath(dirname(pathof(Molly)), "..", "data")
const ff_dir = joinpath(data_dir, "force_fields")
const openmm_dir = joinpath(data_dir, "openmm_6mrr")

function setup_system(ArrayType::AbstractArray, f32::Bool, units::Bool)
function setup_system(array_type::AbstractArray, f32::Bool, units::Bool)
T = f32 ? Float32 : Float64
ff = MolecularForceField(
T,
Expand All @@ -27,7 +27,7 @@ function setup_system(ArrayType::AbstractArray, f32::Bool, units::Bool)
sys = System(
joinpath(data_dir, "6mrr_equil.pdb"),
ff;
velocities=ArrayType(velocities),
velocities=array_type(velocities),
units=units,
gpu=gpu,
dist_cutoff=(units ? dist_cutoff * u"nm" : dist_cutoff),
Expand Down
20 changes: 10 additions & 10 deletions src/interactions/implicit_solvent.jl
Original file line number Diff line number Diff line change
Expand Up @@ -412,10 +412,10 @@ function ImplicitSolventOBC(atoms::AbstractArray{Atom{TY, M, T, D, E}},
end

if isa(atoms, AbstractGPUArray)
ArrayType = get_array_type(atoms)
or = ArrayType(offset_radii)
sor = ArrayType(scaled_offset_radii)
is, js = ArrayType(inds_i), ArrayType(inds_j)
array_type = get_array_type(atoms)
or = array_type(offset_radii)
sor = array_type(scaled_offset_radii)
is, js = array_type(inds_i), array_type(inds_j)
else
or = offset_radii
sor = scaled_offset_radii
Expand Down Expand Up @@ -565,12 +565,12 @@ function ImplicitSolventGBN2(atoms::AbstractArray{Atom{TY, M, T, D, E}},
end

if isa(atoms, AbstractGPUArray)
ArrayType = get_array_type(atoms)
or = ArrayType(offset_radii)
sor = ArrayType(scaled_offset_radii)
is, js = ArrayType(inds_i), ArrayType(inds_j)
d0s, m0s = ArrayType(table_d0), ArrayType(table_m0)
αs, βs, γs = ArrayType(αs_cpu), ArrayType(βs_cpu), ArrayType(γs_cpu)
array_type = get_array_type(atoms)
or = array_type(offset_radii)
sor = array_type(scaled_offset_radii)
is, js = array_type(inds_i), array_type(inds_j)
d0s, m0s = array_type(table_d0), array_type(table_m0)
αs, βs, γs = array_type(αs_cpu), array_type(βs_cpu), array_type(γs_cpu)
else
or = offset_radii
sor = scaled_offset_radii
Expand Down
84 changes: 42 additions & 42 deletions src/setup.jl
Original file line number Diff line number Diff line change
Expand Up @@ -428,7 +428,7 @@ are not available when reading Gromacs files.
- `loggers=()`: the loggers that record properties of interest during a
simulation.
- `units::Bool=true`: whether to use Unitful quantities.
- `ArrayType::AbstractArray = Array`: The ArrayType desired for the simulation
- `array_type::AbstractArray = Array`: The array_type desired for the simulation
(for GPU support, use CuArray or ROCArray)
- `dist_cutoff=1.0u"nm"`: cutoff distance for long-range interactions.
- `dist_neighbors=1.2u"nm"`: cutoff distance for the neighbor list, should be
Expand All @@ -452,7 +452,7 @@ function System(coord_file::AbstractString,
velocities=nothing,
loggers=(),
units::Bool=true,
ArrayType::Type{AT} where AT <: AbstractArray = Array,
array_type::Type{AT} where AT <: AbstractArray = Array,
dist_cutoff=units ? 1.0u"nm" : 1.0,
dist_neighbors=units ? 1.2u"nm" : 1.2,
center_coords::Bool=true,
Expand Down Expand Up @@ -824,9 +824,9 @@ function System(coord_file::AbstractString,
specific_inter_array = []
if length(bonds.is) > 0
push!(specific_inter_array, InteractionList2Atoms(
ArrayType(bonds.is),
ArrayType(bonds.js),
ArrayType([bonds.inters...]),
array_type(bonds.is),
array_type(bonds.js),
array_type([bonds.inters...]),
bonds.types,
))
topology = MolecularTopology(bonds.is, bonds.js, n_atoms)
Expand All @@ -835,30 +835,30 @@ function System(coord_file::AbstractString,
end
if length(angles.is) > 0
push!(specific_inter_array, InteractionList3Atoms(
ArrayType(angles.is),
ArrayType(angles.js),
ArrayType(angles.ks),
ArrayType([angles.inters...]),
array_type(angles.is),
array_type(angles.js),
array_type(angles.ks),
array_type([angles.inters...]),
angles.types,
))
end
if length(torsions.is) > 0
push!(specific_inter_array, InteractionList4Atoms(
ArrayType(torsions.is),
ArrayType(torsions.js),
ArrayType(torsions.ks),
ArrayType(torsions.ls),
ArrayType(torsion_inters_pad),
array_type(torsions.is),
array_type(torsions.js),
array_type(torsions.ks),
array_type(torsions.ls),
array_type(torsion_inters_pad),
torsions.types,
))
end
if length(impropers.is) > 0
push!(specific_inter_array, InteractionList4Atoms(
ArrayType(impropers.is),
ArrayType(impropers.js),
ArrayType(impropers.ks),
ArrayType(impropers.ls),
ArrayType(improper_inters_pad),
array_type(impropers.is),
array_type(impropers.js),
array_type(impropers.ks),
array_type(impropers.ls),
array_type(improper_inters_pad),
impropers.types,
))
end
Expand Down Expand Up @@ -887,10 +887,10 @@ function System(coord_file::AbstractString,
end
coords = wrap_coords.(coords, (boundary_used,))

if (ArrayType <: AbstractGPUArray) || !use_cell_list
if (array_type <: AbstractGPUArray) || !use_cell_list
neighbor_finder = DistanceNeighborFinder(
eligible=(ArrayType(eligible)),
special=(ArrayType(special)),
eligible=(array_type(eligible)),
special=(array_type(special)),
n_steps=10,
dist_cutoff=T(dist_neighbors),
)
Expand All @@ -905,8 +905,8 @@ function System(coord_file::AbstractString,
)
end

atoms = ArrayType([atoms_abst...])
coords_dev = ArrayType(coords)
atoms = array_type([atoms_abst...])
coords_dev = array_type(coords)

if isnothing(velocities)
if units
Expand Down Expand Up @@ -961,7 +961,7 @@ function System(T::Type,
velocities=nothing,
loggers=(),
units::Bool=true,
ArrayType::Type{AT} where AT <: AbstractArray = Array,
array_type::Type{AT} where AT <: AbstractArray = Array,
dist_cutoff=units ? 1.0u"nm" : 1.0,
dist_neighbors=units ? 1.2u"nm" : 1.2,
center_coords::Bool=true,
Expand Down Expand Up @@ -1242,9 +1242,9 @@ function System(T::Type,
specific_inter_array = []
if length(bonds.is) > 0
push!(specific_inter_array, InteractionList2Atoms(
ArrayType(bonds.is),
ArrayType(bonds.js),
ArrayType([bonds.inters...]),
array_type(bonds.is),
array_type(bonds.js),
array_type([bonds.inters...]),
bonds.types,
))
topology = MolecularTopology(bonds.is, bonds.js, n_atoms)
Expand All @@ -1253,29 +1253,29 @@ function System(T::Type,
end
if length(angles.is) > 0
push!(specific_inter_array, InteractionList3Atoms(
ArrayType(angles.is),
ArrayType(angles.js),
ArrayType(angles.ks),
ArrayType([angles.inters...]),
array_type(angles.is),
array_type(angles.js),
array_type(angles.ks),
array_type([angles.inters...]),
angles.types,
))
end
if length(torsions.is) > 0
push!(specific_inter_array, InteractionList4Atoms(
ArrayType(torsions.is),
ArrayType(torsions.js),
ArrayType(torsions.ks),
ArrayType(torsions.ls),
ArrayType([torsions.inters...]),
array_type(torsions.is),
array_type(torsions.js),
array_type(torsions.ks),
array_type(torsions.ls),
array_type([torsions.inters...]),
torsions.types,
))
end
specific_inter_lists = tuple(specific_inter_array...)

if ArrayType <: AbstractGPUArray || !use_cell_list
if array_type <: AbstractGPUArray || !use_cell_list
neighbor_finder = DistanceNeighborFinder(
eligible=(ArrayType(eligible)),
special=(ArrayType(special)),
eligible=(array_type(eligible)),
special=(array_type(special)),
n_steps=10,
dist_cutoff=T(dist_neighbors),
)
Expand All @@ -1290,8 +1290,8 @@ function System(T::Type,
)
end

atoms = ArrayType([atoms_abst...])
coords_dev = ArrayType(coords)
atoms = array_type([atoms_abst...])
coords_dev = array_type(coords)

if isnothing(velocities)
if units
Expand Down
4 changes: 2 additions & 2 deletions src/spatial.jl
Original file line number Diff line number Diff line change
Expand Up @@ -876,8 +876,8 @@ function molecule_centers(coords::AbstractArray{SVector{D, C}}, boundary, topolo
end

function molecule_centers(coords::AbstractGPUArray, boundary, topology)
ArrayType = get_array_type(coords)
return ArrayType(molecule_centers(Array(coords), boundary, topology))
array_type = get_array_type(coords)
return array_type(molecule_centers(Array(coords), boundary, topology))
end

# Allows scaling multiple vectors at once by broadcasting this function
Expand Down
16 changes: 8 additions & 8 deletions src/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -183,23 +183,23 @@ function Base.:+(il1::InteractionList4Atoms{I, T}, il2::InteractionList4Atoms{I,
)
end

function inject_interaction_list(inter::InteractionList1Atoms, params_dic, ArrayType)
inters_grad = ArrayType(inject_interaction.(Array(inter.inters), inter.types, (params_dic,)))
function inject_interaction_list(inter::InteractionList1Atoms, params_dic, array_type)
inters_grad = array_type(inject_interaction.(Array(inter.inters), inter.types, (params_dic,)))
InteractionList1Atoms(inter.is, inters_grad, inter.types)
end

function inject_interaction_list(inter::InteractionList2Atoms, params_dic, ArrayType)
inters_grad = ArrayType(inject_interaction.(Array(inter.inters), inter.types, (params_dic,)))
function inject_interaction_list(inter::InteractionList2Atoms, params_dic, array_type)
inters_grad = array_type(inject_interaction.(Array(inter.inters), inter.types, (params_dic,)))
InteractionList2Atoms(inter.is, inter.js, inters_grad, inter.types)
end

function inject_interaction_list(inter::InteractionList3Atoms, params_dic, ArrayType)
inters_grad = ArrayType(inject_interaction.(Array(inter.inters), inter.types, (params_dic,)))
function inject_interaction_list(inter::InteractionList3Atoms, params_dic, array_type)
inters_grad = array_type(inject_interaction.(Array(inter.inters), inter.types, (params_dic,)))
InteractionList3Atoms(inter.is, inter.js, inter.ks, inters_grad, inter.types)
end

function inject_interaction_list(inter::InteractionList4Atoms, params_dic, ArrayType)
inters_grad = ArrayType(inject_interaction.(Array(inter.inters), inter.types, (params_dic,)))
function inject_interaction_list(inter::InteractionList4Atoms, params_dic, array_type)
inters_grad = array_type(inject_interaction.(Array(inter.inters), inter.types, (params_dic,)))
InteractionList4Atoms(inter.is, inter.js, inter.ks, inter.ls, inters_grad, inter.types)
end

Expand Down
10 changes: 5 additions & 5 deletions test/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -176,16 +176,16 @@
@test mcs == [SVector(0.05, 0.0), SVector(1.0, 1.0)]

ff = MolecularForceField(joinpath.(ff_dir, ["ff99SBildn.xml", "tip3p_standard.xml", "his.xml"])...)
for ArrayType in array_list
sys = System(joinpath(data_dir, "6mrr_equil.pdb"), ff; ArrayType=ArrayType, use_cell_list=false)
for array_type in array_list
sys = System(joinpath(data_dir, "6mrr_equil.pdb"), ff; array_type=array_type, use_cell_list=false)
mcs = molecule_centers(sys.coords, sys.boundary, sys.topology)
@test isapprox(Array(mcs)[1], mean(sys.coords[1:1170]); atol=0.08u"nm")

# Mark all pairs as ineligible for pairwise interactions and check that the
# potential energy from the specific interactions does not change on scaling
no_nbs = falses(length(sys), length(sys))
sys.neighbor_finder = DistanceNeighborFinder(
eligible=(ArrayType(no_nbs)),
eligible=(array_type(no_nbs)),
dist_cutoff=1.0u"nm",
)
coords_start = copy(sys.coords)
Expand Down Expand Up @@ -312,7 +312,7 @@ end

if run_cuda_tests
sys_gpu = System(joinpath(data_dir, "6mrr_equil.pdb"), ff;
ArrayType=CuArray)
array_type=CuArray)
for neighbor_finder in (DistanceNeighborFinder,)
nf_gpu = neighbor_finder(
eligible=sys_gpu.neighbor_finder.eligible,
Expand All @@ -330,7 +330,7 @@ end

if run_rocm_tests
sys_gpu = System(joinpath(data_dir, "6mrr_equil.pdb"), ff;
ArrayType=ROCArray)
array_type=ROCArray)
for neighbor_finder in (DistanceNeighborFinder,)
nf_gpu = neighbor_finder(
eligible=sys_gpu.neighbor_finder.eligible,
Expand Down
6 changes: 3 additions & 3 deletions test/energy_conservation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ using CUDA
using Test

@testset "Lennard-Jones energy conservation" begin
function test_energy_conservation(ArrayType::AbstractArray, n_threads::Integer, n_steps::Integer)
function test_energy_conservation(array_type::AbstractArray, n_threads::Integer, n_steps::Integer)
n_atoms = 2_000
atom_mass = 40.0u"g/mol"
temp = 1.0u"K"
Expand All @@ -26,8 +26,8 @@ using Test
coords = place_atoms(n_atoms, boundary; min_dist=0.6u"nm")

sys = System(
atoms=(ArrayType(atoms) : atoms),
coords=(ArrayType(coords) : coords),
atoms=(array_type(atoms) : atoms),
coords=(array_type(coords) : coords),
boundary=boundary,
pairwise_inters=(LennardJones(cutoff=cutoff, use_neighbors=false),),
loggers=(
Expand Down
4 changes: 2 additions & 2 deletions test/gradients.jl
Original file line number Diff line number Diff line change
Expand Up @@ -251,13 +251,13 @@ end
end

@testset "Differentiable protein" begin
function create_sys(ArrayType)
function create_sys(array_type)
ff = MolecularForceField(joinpath.(ff_dir, ["ff99SBildn.xml", "his.xml"])...; units=false)
return System(
joinpath(data_dir, "6mrr_nowater.pdb"),
ff;
units=false,
ArrayType=ArrayType,
array_type=array_type,
implicit_solvent="gbn2",
kappa=0.7,
)
Expand Down
Loading

0 comments on commit 5c1b912

Please sign in to comment.