diff --git a/src/CellLists.jl b/src/CellLists.jl index a4bd9c0f..5b9be77f 100644 --- a/src/CellLists.jl +++ b/src/CellLists.jl @@ -1,5 +1,5 @@ # -# This file contains all structre types and functions necessary for building +# This file contains all structure types and functions necessary for building # the CellList and CellListPair structures. # @@ -157,35 +157,25 @@ end #= -Structures to control dispatch on swapped vs. not swapped cell list pairs. - -=# -struct Swapped end -struct NotSwapped end - -#= - $(TYPEDEF) # Extended help $(TYPEDFIELDS) -Structure that will cointain the cell lists of two independent sets of +Structure that will contains the cell lists of two independent sets of particles for cross-computation of interactions =# -struct CellListPair{V,N,T,Swap} - ref::V +struct CellListPair{N,T} + ref::CellList{N,T} target::CellList{N,T} end -CellListPair(ref::V, target::CellList{N,T}, ::Swap) where {V,N,T,Swap} = - CellListPair{V,N,T,Swap}(ref, target) function Base.show(io::IO, ::MIME"text/plain", cl::CellListPair) _print(io, typeof(cl), "\n") - _print(io, " $(length(cl.ref)) particles in the reference vector.\n") - _print(io, " $(cl.target.n_cells_with_real_particles) cells with real particles of target vector.") + _print(io, " $(cl.ref.n_cells_with_real_particles) cells with real particles in smallest set.\n") + _print(io, " $(cl.target.n_cells_with_real_particles) cells with real particles largest set.") end #= @@ -230,30 +220,13 @@ _nbatches_build_cell_lists(n::Int) = max(1, min(n, min(8, nthreads()))) _nbatches_map_computation(n::Int) = max(1, min(n, min(floor(Int, 2^(log10(n) + 1)), nthreads()))) function set_number_of_batches!( - cl::CellListPair{V,N,T,Swap}, + cl::CellListPair{N,T}, nbatches::Tuple{Int,Int}=(0, 0); parallel=true -) where {V,N,T,Swap} - if parallel - nbatches = NumberOfBatches(nbatches) - else - if nbatches != (0, 0) && nbatches != (1, 1) - println("WARNING: nbatches set to $nbatches, but parallel is set to false, implying nbatches == (1, 1)") - end - nbatches = NumberOfBatches((1, 1)) - end - if nbatches.build_cell_lists < 1 - n1 = _nbatches_build_cell_lists(cl.target.n_real_particles) - else - n1 = nbatches.build_cell_lists - end - if nbatches.map_computation < 1 - n2 = _nbatches_map_computation(length(cl.ref)) - else - n2 = nbatches.map_computation - end - cl.target.nbatches = NumberOfBatches(n1, n2) - return CellListPair{V,N,T,Swap}(cl.ref, cl.target) +) where {N,T} + cl.ref = set_number_of_batches!(cl.ref, nbatches; parallel) + cl.target = set_number_of_batches!(cl.target, nbatches; parallel) + return CellListPair{N,T}(cl.ref, cl.target) end """ @@ -290,8 +263,8 @@ function nbatches(cl::CellList, s::Symbol) s == :map_computation || s == :map && return cl.nbatches.map_computation s == :build_cell_lists || s == :build && return cl.nbatches.build_cell_lists end -nbatches(cl::CellListPair) = nbatches(cl.target) -nbatches(cl::CellListPair, s::Symbol) = nbatches(cl.target, s) +nbatches(cl::CellListPair) = (nbatches(cl.ref), nbatches(cl.target)) +nbatches(cl::CellListPair, s::Symbol) = (nbatches(cl.ref, s), nbatches(cl.target, s)) #= @@ -306,11 +279,19 @@ be considered by each thread on parallel construction. =# @with_kw struct AuxThreaded{N,T} - particles_per_batch::Int idxs::Vector{UnitRange{Int}} = Vector{UnitRange{Int}}(undef, 0) lists::Vector{CellList{N,T}} = Vector{CellList{N,T}}(undef, 0) end -function Base.show(io::IO, ::MIME"text/plain", aux::AuxThreaded) +function Base.show(io::IO, ::MIME"text/plain", aux::AuxThreaded{<:CellList}) + _println(io, typeof(aux)) + _print(io, " Auxiliary arrays for nbatches = ", length(aux.lists)) +end + +@with_kw struct AuxThreadedPair{N,T} + ref::AuxThreaded{N,T} + target::AuxThreaded{N,T} +end +function Base.show(io::IO, ::MIME"text/plain", aux::AuxThreaded{<:CellList}) _println(io, typeof(aux)) _print(io, " Auxiliary arrays for nbatches = ", length(aux.lists)) end @@ -341,10 +322,9 @@ CellList{3, Float64} ``` """ -function AuxThreaded(cl::CellList{N,T}; particles_per_batch=10_000) where {N,T} +function AuxThreaded(cl::CellList{N,T}) where {N,T} nbatches = cl.nbatches.build_cell_lists aux = AuxThreaded{N,T}( - particles_per_batch=particles_per_batch, idxs=Vector{UnitRange{Int}}(undef, nbatches), lists=Vector{CellList{N,T}}(undef, nbatches) ) @@ -352,10 +332,7 @@ function AuxThreaded(cl::CellList{N,T}; particles_per_batch=10_000) where {N,T} nbatches == 1 && return aux @sync for ibatch in eachindex(aux.lists) @spawn begin - cl_batch = CellList{N,T}( - n_real_particles=particles_per_batch, # this is reset before filling, in UpdateCellList! - number_of_cells=cl.number_of_cells, - ) + cl_batch = CellList{N,T}(number_of_cells=cl.number_of_cells) aux.lists[ibatch] = cl_batch end end @@ -420,8 +397,7 @@ CellList{3, Float64} ``` """ -AuxThreaded(cl_pair::CellListPair; particles_per_batch=10_000) = - AuxThreaded(cl_pair.target, particles_per_batch=particles_per_batch) +AuxThreaded(cl_pair::CellListPair) = AuxThreaded(cl_pair.target) """ CellList(