Skip to content

Commit

Permalink
implement validate_coordinates
Browse files Browse the repository at this point in the history
  • Loading branch information
lmiq committed Aug 6, 2024
1 parent 1b9e43d commit 0f6554b
Show file tree
Hide file tree
Showing 3 changed files with 223 additions and 67 deletions.
136 changes: 96 additions & 40 deletions src/CellLists.jl
Original file line number Diff line number Diff line change
Expand Up @@ -428,7 +428,8 @@ AuxThreaded(cl_pair::CellListPair; particles_per_batch=10_000) =
x::AbstractVector{AbstractVector},
box::Box{UnitCellType,N,T};
parallel::Bool=true,
nbatches::Tuple{Int,Int}=(0,0)
nbatches::Tuple{Int,Int}=(0,0),
validate_coordinates::Union{Function,Nothing}=_validate_coordinates
) where {UnitCellType,N,T}
Function that will initialize a `CellList` structure from scratch, given a vector
Expand Down Expand Up @@ -457,14 +458,12 @@ function CellList(
x::AbstractVector{<:AbstractVector},
box::Box{UnitCellType,N,T};
parallel::Bool=true,
nbatches::Tuple{Int,Int}=(0, 0)
nbatches::Tuple{Int,Int}=(0, 0),
validate_coordinates::Union{Function,Nothing}=_validate_coordinates,
) where {UnitCellType,N,T}
cl = CellList{N,T}(
n_real_particles=length(x),
number_of_cells=prod(box.nc),
)
set_number_of_batches!(cl, nbatches, parallel=parallel)
return UpdateCellList!(x, box, cl, parallel=parallel)
cl = CellList{N,T}(n_real_particles=length(x), number_of_cells=prod(box.nc))
set_number_of_batches!(cl, nbatches; parallel)
return UpdateCellList!(x, box, cl; parallel, validate_coordinates)
end

#=
Expand Down Expand Up @@ -515,7 +514,8 @@ end
box::Box{UnitCellType,N,T};
parallel::Bool=true,
nbatches::Tuple{Int,Int}=(0,0),
autoswap::Bool=true
autoswap::Bool=true,
validate_coordinates::Union{Function,Nothing}=_validate_coordinates
) where {UnitCellType,N,T}
Function that will initialize a `CellListPair` structure from scratch, given two vectors
Expand Down Expand Up @@ -551,19 +551,22 @@ function CellList(
box::Box{UnitCellType,N,T};
parallel::Bool=true,
nbatches::Tuple{Int,Int}=(0, 0),
autoswap=true
autoswap=true,
validate_coordinates::Union{Function,Nothing}=_validate_coordinates,
) where {UnitCellType,N,T}
if !autoswap || length(x) >= length(y)
isnothing(validate_coordinates) || validate_coordinates(x)
ref = [SVector{N,T}(ntuple(i -> el[i], N)) for el in x]
target = CellList(y, box, parallel=parallel)
target = CellList(y, box; parallel, validate_coordinates)
swap = NotSwapped()
else
isnothing(validate_coordinates) || validate_coordinates(y)
ref = [SVector{N,T}(ntuple(i -> el[i], N)) for el in y]
target = CellList(x, box, parallel=parallel)
target = CellList(x, box; parallel, validate_coordinates)
swap = Swapped()
end
cl_pair = CellListPair(ref, target, swap)
cl_pair = set_number_of_batches!(cl_pair, nbatches, parallel=parallel)
cl_pair = set_number_of_batches!(cl_pair, nbatches; parallel)
return cl_pair
end

Expand All @@ -587,15 +590,21 @@ end
UpdateCellList!(
x::AbstractVector{<:AbstractVector},
box::Box,
cl:CellList,
parallel=true
cl:CellList;
parallel=true,
validate_coordinates::Union{Function,Nothing}=_validate_coordinates
)
Function that will update a previously allocated `CellList` structure, given new
updated particle positions. This function will allocate new threaded auxiliary
arrays in parallel calculations. To preallocate these auxiliary arrays, use
the `UpdateCellList!(x,box,cl,aux)` method instead.
The `validate_coordinates` function is called before the update of the cell list, and
should throw an error if the coordinates are invalid. By default, this function
throws an error if some coordinates are missing or are NaN. Set to `nothing` to disable
this check, or provide a custom function.
## Example
```julia-repl
Expand All @@ -618,13 +627,14 @@ function UpdateCellList!(
x::AbstractVector{<:AbstractVector},
box::Box,
cl::CellList;
parallel::Bool=true
parallel::Bool=true,
validate_coordinates = _validate_coordinates,
)
if parallel
aux = AuxThreaded(cl)
return UpdateCellList!(x, box, cl, aux, parallel=parallel)
return UpdateCellList!(x, box, cl, aux; parallel, validate_coordinates)
else
return UpdateCellList!(x, box, cl, nothing, parallel=parallel)
return UpdateCellList!(x, box, cl, nothing; parallel, validate_coordinates)
end
end

Expand All @@ -633,7 +643,8 @@ end
x::AbstractMatrix,
box::Box,
cl::CellList{N,T};
parallel::Bool=true
parallel::Bool=true,
validate_coordinates::Union{Function,Nothing}=_validate_coordinates,
) where {N,T}
Reinterprets the matrix `x` as vectors of static vectors and calls the
Expand All @@ -645,11 +656,11 @@ function UpdateCellList!(
x::AbstractMatrix,
box::Box,
cl::CellList{N,T};
parallel::Bool=true
kargs...
) where {N,T}
size(x, 1) == N || throw(DimensionMismatch("First dimension of input matrix must be $N"))
x_re = reinterpret(reshape, SVector{N,eltype(x)}, x)
return UpdateCellList!(x_re, box, cl, parallel=parallel)
return UpdateCellList!(x_re, box, cl; kargs...)
end

"""
Expand All @@ -658,7 +669,8 @@ end
box::Box,
cl::CellList{N,T},
aux::Union{Nothing,AuxThreaded{N,T}};
parallel::Bool=true
parallel::Bool=true,
validate_coordinates::Union{Function,Nothing}=_validate_coordinates
) where {N,T}
Function that updates the cell list `cl` new coordinates `x` and possibly a new
Expand Down Expand Up @@ -719,9 +731,13 @@ function UpdateCellList!(
box::Box,
cl::CellList{N,T},
aux::Union{Nothing,AuxThreaded{N,T}};
parallel::Bool=true
parallel::Bool=true,
validate_coordinates = _validate_coordinates,
) where {N,T}

# validate coordinates
isnothing(validate_coordinates) || validate_coordinates(x)

# Provide a better error message if the unit cell dimension does not match the dimension of the positions.
if length(x) > 0 && (length(x[begin]) != size(box.input_unit_cell.matrix, 1))
n1 = length(x[begin])
Expand Down Expand Up @@ -773,7 +789,8 @@ end
box::Box,
cl::CellList{N,T},
aux::Union{Nothing,AuxThreaded{N,T}};
parallel::Bool=true
parallel::Bool=true,
validate_coordinates=_validate_coordinates,
) where {N,T}
Reinterprets the matrix `x` as vectors of static vectors and calls the
Expand All @@ -786,11 +803,11 @@ function UpdateCellList!(
box::Box,
cl::CellList{N,T},
aux::Union{Nothing,AuxThreaded{N,T}};
parallel::Bool=true
kargs...
) where {N,T}
size(x, 1) == N || throw(DimensionMismatch("First dimension of input matrix must be $N"))
x_re = reinterpret(reshape, SVector{N,eltype(x)}, x)
return UpdateCellList!(x_re, box, cl, aux, parallel=parallel)
return UpdateCellList!(x_re, box, cl, aux; kargs...)
end

#=
Expand Down Expand Up @@ -1038,14 +1055,20 @@ end
y::AbstractVector{<:AbstractVector},
box::Box,
cl:CellListPair,
parallel=true
parallel=true,
validate_coordinates::Union{Function,Nothing}=_validate_coordinates
)
Function that will update a previously allocated `CellListPair` structure, given
new updated particle positions, for example. This method will allocate new
`aux` threaded auxiliary arrays. For a non-allocating version, see the
`UpdateCellList!(x,y,box,cl,aux)` method.
The `validate_coordinates` function is called before the update of the cell list, and
should throw an error if the coordinates are invalid. By default, this function
throws an error if some coordinates are missing or are NaN. Set to `nothing` to disable
this check, or provide a custom function.
```julia-repl
julia> box = Box([250,250,250],10);
Expand All @@ -1065,13 +1088,14 @@ function UpdateCellList!(
y::AbstractVector{<:AbstractVector},
box::Box,
cl_pair::CellListPair;
parallel::Bool=true
parallel::Bool=true,
kargs...
)
if parallel
aux = AuxThreaded(cl_pair)
return UpdateCellList!(x, y, box, cl_pair, aux, parallel=parallel)
return UpdateCellList!(x, y, box, cl_pair, aux; kargs...)
else
return UpdateCellList!(x, y, box, cl_pair, nothing, parallel=parallel)
return UpdateCellList!(x, y, box, cl_pair, nothing; kargs...)
end
end

Expand All @@ -1094,13 +1118,13 @@ function UpdateCellList!(
y::AbstractMatrix,
box::Box{UnitCellType,N},
cl_pair::CellListPair;
parallel::Bool=true
kargs...
) where {UnitCellType,N}
size(x, 1) == N || throw(DimensionMismatch("First dimension of input matrix must be $N"))
size(y, 1) == N || throw(DimensionMismatch("First dimension of input matrix must be $N"))
x_re = reinterpret(reshape, SVector{N,eltype(x)}, x)
y_re = reinterpret(reshape, SVector{N,eltype(y)}, y)
return UpdateCellList!(x_re, y_re, box, cl_pair, parallel=parallel)
return UpdateCellList!(x_re, y_re, box, cl_pair; kargs...)
end

"""
Expand All @@ -1110,7 +1134,8 @@ end
box::Box,
cl_pair::CellListPair,
aux::Union{Nothing,AuxThreaded};
parallel::Bool=true
parallel::Bool=true,
validate_coordinates::Union{Function,Nothing}=_validate_coordinates
)
This function will update the `cl_pair` structure that contains the cell lists
Expand Down Expand Up @@ -1174,10 +1199,12 @@ function UpdateCellList!(
box::Box,
cl_pair::CellListPair{V,N,T,Swap},
aux::Union{Nothing,AuxThreaded};
parallel::Bool=true
parallel::Bool=true,
validate_coordinates::Union{Nothing,Function}=_validate_coordinates,
) where {V,N,T,Swap<:NotSwapped}
ref = x
target = UpdateCellList!(y, box, cl_pair.target, aux, parallel=parallel)
isnothing(validate_coordinates) || validate_coordinates(x)
target = UpdateCellList!(y, box, cl_pair.target, aux; parallel, validate_coordinates)
cl_pair = _update_CellListPair!(ref, target, cl_pair)
return cl_pair
end
Expand All @@ -1188,10 +1215,12 @@ function UpdateCellList!(
box::Box,
cl_pair::CellListPair{V,N,T,Swap},
aux::Union{Nothing,AuxThreaded};
parallel::Bool=true
parallel::Bool=true,
validate_coordinates::Union{Nothing,Function}=_validate_coordinates,
) where {V,N,T,Swap<:Swapped}
ref = y
target = UpdateCellList!(x, box, cl_pair.target, aux, parallel=parallel)
isnothing(validate_coordinates) || validate_coordinates(y)
target = UpdateCellList!(x, box, cl_pair.target, aux; parallel, validate_coordinates)
cl_pair = _update_CellListPair!(ref, target, cl_pair)
return cl_pair
end
Expand Down Expand Up @@ -1231,13 +1260,13 @@ function UpdateCellList!(
box::Box{UnitCellType,N},
cl_pair::CellListPair,
aux::Union{Nothing,AuxThreaded};
parallel::Bool=true
kargs...
) where {UnitCellType,N}
size(x, 1) == N || throw(DimensionMismatch("First dimension of input matrix must be $N"))
size(y, 1) == N || throw(DimensionMismatch("First dimension of input matrix must be $N"))
x_re = reinterpret(reshape, SVector{N,eltype(x)}, x)
y_re = reinterpret(reshape, SVector{N,eltype(y)}, y)
return UpdateCellList!(x_re, y_re, box, cl_pair, aux, parallel=parallel)
return UpdateCellList!(x_re, y_re, box, cl_pair, aux; kargs...)
end

#=
Expand All @@ -1249,4 +1278,31 @@ Returns the average number of real particles per computing cell.
particles_per_cell(cl::CellList) = cl.n_real_particles / cl.number_of_cells
particles_per_cell(cl::CellListPair) = particles_per_cell(cl.target)


@testitem "celllists - validate coordinates" begin
using CellListMap
using StaticArrays
x = rand(SVector{3,Float64}, 100)
x[50] = SVector(1.0, NaN, 1.0)
box = Box([1.0, 1.0, 1.0], 0.1)
y = rand(SVector{3,Float64}, 100)
@test_throws ArgumentError CellList(x,box)
cl = CellList(y,box)
@test_throws ArgumentError UpdateCellList!(x, box, cl)
@test_throws ArgumentError CellList(x,y,box)
@test_throws ArgumentError CellList(y,x,box)
cl = CellList(y,y,box)
@test_throws ArgumentError UpdateCellList!(x, y, box, cl)
@test_throws ArgumentError UpdateCellList!(y, x, box, cl)
x = rand(3,100)
x[2,50] = NaN
box = Box([1.0, 1.0, 1.0], 0.1)
y = rand(3,100)
@test_throws ArgumentError CellList(x,box)
cl = CellList(y,box)
@test_throws ArgumentError UpdateCellList!(x, box, cl)
@test_throws ArgumentError CellList(x,y,box)
@test_throws ArgumentError CellList(y,x,box)
cl = CellList(y,y,box)
@test_throws ArgumentError UpdateCellList!(x, y, box, cl)
@test_throws ArgumentError UpdateCellList!(y, x, box, cl)
end
Loading

0 comments on commit 0f6554b

Please sign in to comment.