Skip to content

Commit

Permalink
Use Polyester.@batch for threading
Browse files Browse the repository at this point in the history
  • Loading branch information
efaulhaber committed Jul 5, 2024
1 parent 40f69dc commit 42d2e08
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 55 deletions.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a"
Polyester = "f517fe37-dbe3-4b94-8317-1923a5111588"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Expand Down
76 changes: 37 additions & 39 deletions src/CellListMap.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
module CellListMap

using DocStringExtensions: TYPEDEF, TYPEDFIELDS
using DocStringExtensions: TYPEDEF, TYPEDFIELDS
using TestItems: @testitem
using Compat: @compat
using Polyester: @batch
using ProgressMeter: Progress, next!
using Parameters: @unpack, @with_kw
using StaticArrays: SVector, SMatrix, @SVector, @SMatrix, MVector, MMatrix, FieldVector
using Setfield: @set!
using LinearAlgebra: cross, diagm, I
using Base.Threads: nthreads, @spawn
using Base.Threads: nthreads, @spawn
using Base: @lock # not exported in 1.6

export Box
Expand All @@ -24,15 +25,15 @@ export nbatches
# Testing file
const argon_pdb_file = joinpath("$(@__DIR__ )/../test/gromacs/argon/cubic.pdb")

# name holder
# name holder
function map_pairwise! end

"""
map_pairwise(args...;kargs...) = map_pairwise!(args...;kargs...)
is an alias for `map_pairwise!` which is defined for two reasons: first, if the output of the funciton is immutable, it may be
clearer to call this version, from a coding perspective. Second, the python interface through `juliacall` does not accept the
bang as a valid character.
is an alias for `map_pairwise!` which is defined for two reasons: first, if the output of the funciton is immutable, it may be
clearer to call this version, from a coding perspective. Second, the python interface through `juliacall` does not accept the
bang as a valid character.
"""
const map_pairwise = map_pairwise!
Expand All @@ -55,32 +56,32 @@ include("./CoreComputing.jl")
show_progress::Bool=false
)
This function will run over every pair of particles which are closer than
`box.cutoff` and compute the Euclidean distance between the particles,
considering the periodic boundary conditions given in the `Box` structure.
If the distance is smaller than the cutoff, a function `f` of the
coordinates of the two particles will be computed.
This function will run over every pair of particles which are closer than
`box.cutoff` and compute the Euclidean distance between the particles,
considering the periodic boundary conditions given in the `Box` structure.
If the distance is smaller than the cutoff, a function `f` of the
coordinates of the two particles will be computed.
The function `f` receives six arguments as input:
The function `f` receives six arguments as input:
```
f(x,y,i,j,d2,output)
```
Which are the coordinates of one particle, the coordinates of the
second particle, the index of the first particle, the index of the second
particle, the squared distance between them, and the `output` variable.
It has also to return the same `output` variable. Thus, `f` may or not
mutate `output`, but in either case it must return it. With that, it is
possible to compute an average property of the distance of the particles
or, for example, build a histogram. The squared distance `d2` is computed
internally for comparison with the
`cutoff`, and is passed to the `f` because many times it is used for the
desired computation.
Which are the coordinates of one particle, the coordinates of the
second particle, the index of the first particle, the index of the second
particle, the squared distance between them, and the `output` variable.
It has also to return the same `output` variable. Thus, `f` may or not
mutate `output`, but in either case it must return it. With that, it is
possible to compute an average property of the distance of the particles
or, for example, build a histogram. The squared distance `d2` is computed
internally for comparison with the
`cutoff`, and is passed to the `f` because many times it is used for the
desired computation.
## Example
Computing the mean absolute difference in `x` position between random particles,
remembering the number of pairs of `n` particles is `n(n-1)/2`. The function does
not use the indices or the distance, such that we remove them from the parameters
Computing the mean absolute difference in `x` position between random particles,
remembering the number of pairs of `n` particles is `n(n-1)/2`. The function does
not use the indices or the distance, such that we remove them from the parameters
by using a closure.
```julia-repl
Expand All @@ -101,7 +102,7 @@ julia> avg_dx = normalization * map_parwise!((x,y,i,j,d2,sum_dx) -> f(x,y,sum_dx
```
"""
function map_pairwise!(f::F, output, box::Box, cl::CellList;
function map_pairwise!(f::F, output, box::Box, cl::CellList;
# Parallelization options
parallel::Bool=true,
output_threaded=nothing,
Expand Down Expand Up @@ -135,20 +136,20 @@ function map_pairwise!(f::F1, output, box::Box, cl::CellListPair{V,N,T,Swap};
show_progress::Bool=false
) where {F1,F2,V,N,T,Swap} # F1, F2 Needed for specialization for these functions
if Swap == Swapped
fswap(x,y,i,j,d2,output) = f(y,x,j,i,d2,output)
fswap(x,y,i,j,d2,output) = f(y,x,j,i,d2,output)
else
fswap = f
end
if parallel
output = map_pairwise_parallel!(
fswap,output,box,cl;
output_threaded=output_threaded,
reduce=reduce,
show_progress=show_progress
)
else
# if parallel
# output = map_pairwise_parallel!(
# fswap,output,box,cl;
# output_threaded=output_threaded,
# reduce=reduce,
# show_progress=show_progress
# )
# else
output = map_pairwise_serial!(fswap,output,box,cl,show_progress=show_progress)
end
# end
return output
end

Expand All @@ -165,6 +166,3 @@ include("./testing.jl")
include("precompile.jl")

end # module



30 changes: 14 additions & 16 deletions src/CoreComputing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,24 +6,24 @@ splitter(first, nbatches, n) = first:nbatches:n
#=
reduce(output, output_threaded)
Most common reduction function, which sums the elements of the output.
Most common reduction function, which sums the elements of the output.
Here, `output_threaded` is a vector containing `nbatches(cl)` copies of
the `output` variable (a scalar or an array). Custom reduction functions
must replace this one if the reduction operation is not a simple sum.
the `output` variable (a scalar or an array). Custom reduction functions
must replace this one if the reduction operation is not a simple sum.
The `output_threaded` array is, by default, created automatically by copying
the given `output` variable `nbatches(cl)` times.
the given `output` variable `nbatches(cl)` times.
## Examples
Scalar reduction:
Scalar reduction:
```julia-repl
julia> output = 0.; output_threaded = [ 1, 2 ];
julia> CellListMap.reduce(output,output_threaded)
3
```
Array reduction:
```julia-repl
Expand Down Expand Up @@ -60,8 +60,8 @@ function reduce(output, output_threaded)
custom_reduce(output::$T, output_threaded::Vector{$T})
```
The reduction function **must** return the `output` variable, even
if it is mutable.
The reduction function **must** return the `output` variable, even
if it is mutable.
See: https://m3g.github.io/CellListMap.jl/stable/parallelization/#Custom-reduction-functions
Expand Down Expand Up @@ -158,8 +158,8 @@ function map_pairwise_serial!(
show_progress::Bool=false
) where {F,N,T}
p = show_progress ? Progress(length(cl.ref), dt=1) : nothing
for i in eachindex(cl.ref)
output = inner_loop!(f, output, i, box, cl)
@batch for i in eachindex(cl.ref)
inner_loop!(f, output, i, box, cl)
_next!(p)
end
return output
Expand Down Expand Up @@ -289,7 +289,7 @@ function _vinicial_cells!(f::F, box::Box{<:OrthorhombicCellType}, cellᵢ, pp,
xproj = dot(xpᵢ - cellᵢ.center, Δc)
# Partition pp array according to the current projections
n = partition!(el -> abs(el.xproj - xproj) <= cutoff, pp)
# Compute the interactions
# Compute the interactions
for j in 1:n
@inbounds pⱼ = pp[j]
xpⱼ = pⱼ.coordinates
Expand All @@ -312,7 +312,7 @@ function _vinicial_cells!(f::F, box::Box{<:TriclinicCell}, cellᵢ, pp, Δc, out
xproj = dot(xpᵢ - cellᵢ.center, Δc)
# Partition pp array according to the current projections
n = partition!(el -> abs(el.xproj - xproj) <= cutoff, pp)
# Compute the interactions
# Compute the interactions
pᵢ.real || continue
for j in 1:n
@inbounds pⱼ = pp[j]
Expand All @@ -332,8 +332,8 @@ end
# Extended help
Projects all particles of the cell `cellⱼ` into unnitary vector `Δc` with direction
connecting the centers of `cellⱼ` and `cellᵢ`. Modifies `projected_particles`, and
Projects all particles of the cell `cellⱼ` into unnitary vector `Δc` with direction
connecting the centers of `cellⱼ` and `cellᵢ`. Modifies `projected_particles`, and
returns a view of `projected particles, where only the particles for which
the projection on the direction of the cell centers still allows the particle
to be within the cutoff distance of any point of the other cell.
Expand Down Expand Up @@ -393,5 +393,3 @@ function inner_loop!(
end
return output
end


0 comments on commit 42d2e08

Please sign in to comment.