Skip to content

Commit

Permalink
use ChunkSplitters for splitting batches
Browse files Browse the repository at this point in the history
  • Loading branch information
lmiq committed Dec 9, 2024
1 parent 1cd4a1d commit 8bf24d4
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 13 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ authors = ["Leandro Martinez <[email protected]> and contributors"]
version = "0.9.7-DEV"

[deps]
ChunkSplitters = "ae650224-84b6-46f8-82ea-d812ca08434e"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand All @@ -19,6 +20,7 @@ TestItems = "1c621080-faea-4a02-84b6-bbd5e436b8fe"
Aqua = "0.8.5"
BenchmarkTools = "1.4"
Chemfiles = "0.10.31"
ChunkSplitters = "3.1.0"
Compat = "4.14.0"
DocStringExtensions = "0.9"
Documenter = "1.2.1"
Expand Down
1 change: 1 addition & 0 deletions src/CellListMap.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ using Setfield: @set!
using LinearAlgebra: cross, diagm, I
using Base.Threads: nthreads, @spawn
using Base: @lock # not exported in 1.6
using ChunkSplitters: index_chunks, RoundRobin, Consecutive

export Box
export CellList, UpdateCellList!
Expand Down
21 changes: 8 additions & 13 deletions src/CoreComputing.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,3 @@
#
# Parallel thread spliiter
#
splitter(first, nbatches, n) = first:nbatches:n

#=
reduce(output, output_threaded)
Expand Down Expand Up @@ -118,8 +113,8 @@ end
#
# Parallel version for self-pairwise computations
#
function batch(f::F, ibatch, nbatches, n_cells_with_real_particles, output_threaded, box, cl, p) where {F}
for i in splitter(ibatch, nbatches, n_cells_with_real_particles)
function batch(f::F, ibatch, cell_indices, output_threaded, box, cl, p) where {F}
for i in cell_indices
cellᵢ = cl.cells[cl.cell_indices_real[i]]
output_threaded[ibatch] = inner_loop!(f, box, cellᵢ, cl, output_threaded[ibatch], ibatch)
_next!(p)
Expand All @@ -143,8 +138,8 @@ function map_pairwise_parallel!(
@unpack n_cells_with_real_particles = cl
nbatches = cl.nbatches.map_computation
p = show_progress ? Progress(n_cells_with_real_particles, dt=1) : nothing
@sync for ibatch in 1:nbatches
@spawn batch($f, $ibatch, $nbatches, $n_cells_with_real_particles, $output_threaded, $box, $cl, $p)
@sync for (ibatch, cell_indices) in enumerate(index_chunks(1:n_cells_with_real_particles; n=nbatches, split=RoundRobin()))
@spawn batch($f, $ibatch, $cell_indices, $output_threaded, $box, $cl, $p)
end
return reduce(output, output_threaded)
end
Expand All @@ -168,8 +163,8 @@ end
#
# Parallel version for cross-interaction computations
#
function batch(f::F, ibatch, nbatches, output_threaded, box, cl, p) where {F}
for i in splitter(ibatch, nbatches, length(cl.ref))
function batch_cross(f::F, ibatch, ref_atom_indices, output_threaded, box, cl, p) where {F}
for i in ref_atom_indices
output_threaded[ibatch] = inner_loop!(f, output_threaded[ibatch], i, box, cl)
_next!(p)
end
Expand All @@ -187,8 +182,8 @@ function map_pairwise_parallel!(
output_threaded = [deepcopy(output) for i in 1:nbatches]
end
p = show_progress ? Progress(length(cl.ref), dt=1) : nothing
@sync for ibatch in 1:nbatches
@spawn batch($f, $ibatch, $nbatches, $output_threaded, $box, $cl, $p)
@sync for (ibatch, ref_atom_indices) in enumerate(index_chunks(1:length(cl.ref); n=nbatches, split=Consecutive()))
@spawn batch_cross($f, $ibatch, $ref_atom_indices, $output_threaded, $box, $cl, $p)
end
return reduce(output, output_threaded)
end
Expand Down

0 comments on commit 8bf24d4

Please sign in to comment.