Merge pull request #4 from JaydevSR/kernel-tiled
Tiled Kernel
JaydevSR authored Jun 20, 2023
2 parents 88afacb + 93bb84b commit a533e65
Showing 1 changed file with 34 additions and 31 deletions.
65 changes: 34 additions & 31 deletions src/cuda.jl
@@ -1,4 +1,6 @@
 # CUDA.jl kernels
+const WARPSIZE = UInt32(32)
+const FULLMASK = 0xffffffff
 
 function cuda_threads_blocks_pairwise(n_neighbors)
     n_threads_gpu = min(n_neighbors, parse(Int, get(ENV, "MOLLY_GPUNTHREADS_PAIRWISE", "512")))
@@ -18,9 +20,8 @@ function pairwise_force_gpu(coords::AbstractArray{SVector{D, C}}, atoms, boundar
 
     if typeof(nbs) == NoNeighborList
         # Use 2D grid with 1D thread blocks
-        n_atoms = length(atoms)
-        n_threads_gpu, _ = cuda_threads_blocks_pairwise(length(atoms))
-        CUDA.@sync @cuda threads=n_threads_gpu blocks=n_atoms pairwise_force_kernel_nonl!(
+        n_threads_gpu, n_blocks = cuda_threads_blocks_pairwise(length(atoms))
+        CUDA.@sync @cuda threads=n_threads_gpu blocks=(n_blocks, n_blocks) pairwise_force_kernel_nonl!(
             fs_mat, coords, atoms, boundary, pairwise_inters, nbs, Val(D), Val(force_units))
     else
         n_threads_gpu, n_blocks = cuda_threads_blocks_pairwise(length(nbs))
@@ -65,48 +66,50 @@ function pairwise_force_kernel_nonl!(forces::AbstractArray{T}, coords_var, atoms
                                      neighbors_var, ::Val{D}, ::Val{F}) where {T, D, F}
     coords = CUDA.Const(coords_var)
     atoms = CUDA.Const(atoms_var)
-    n_atoms = length(atoms)
 
-    tix = threadIdx().x
-    i = blockIdx().x
+    tidx = threadIdx().x
     threads = blockDim().x
+    i_0_block = (blockIdx().x - 1) * threads + 1
+    j_0_block = (blockIdx().y - 1) * threads + 1
+    lane = (tidx - 1) % WARPSIZE + 1
+    warpidx = cld(tidx, WARPSIZE)
 
+    # @cushow tidx i_0_block j_0_block lane warpidx threads
+
     forces_shmem = @cuStaticSharedMem(T, (3, 1024))
-    @inbounds for dim in 1:D
-        forces_shmem[dim, tix] = zero(T)
+    @inbounds for dim in 1:3
+        forces_shmem[dim, tidx] = zero(T)
     end
 
-    # Calculate forces using a block stride loop
-    @inbounds if i <= n_atoms
+    # iterate over horizontal tiles of size warpsize * warpsize to cover all j's
+    i_0_tile = i_0_block + (warpidx - 1) * WARPSIZE
+    tilerange = j_0_block:WARPSIZE:j_0_block + threads - 1
+    for j_0_tile in tilerange # TODO: Ensure i, j in bounds
+        # Load data on the diagonal
+        i = i_0_tile + lane - 1
+        j = j_0_tile + lane - 1
         atom_i, coord_i = atoms[i], coords[i]
-        for j=tix:threads:n_atoms
-            if j != i
-                f = sum_pairwise_forces(inters, coord_i, coords[j], atom_i, atoms[j], boundary, false, F)
-                for dim in 1:D
-                    forces_shmem[dim, tix] += -ustrip(f[dim])
-                end
-            end
+
+        tilesteps = WARPSIZE
+        if i_0_tile == j_0_tile # Don't compute i-i forces
+            j = j % (j_0_tile + WARPSIZE - 1) + 1
+            tilesteps -= 1
         end
-    end
 
-    # Binary tree accumulation
-    d = Int32(1)
-    while d < threads
-        sync_threads()
-        idx = Int32(2) * d * (tix - Int32(1)) + Int32(1)
-        @inbounds if idx <= threads && idx + d <= threads
+        for _ in 1:tilesteps
+            sync_warp()
+            atom_j, coord_j = atoms[j], coords[j] # TODO: shuffle this as well
+            f = sum_pairwise_forces(inters, coord_i, coord_j, atom_i, atom_j, boundary, false, F)
             for dim in 1:D
-                forces_shmem[dim, idx] += forces_shmem[dim, idx+d]
+                forces_shmem[dim, tidx] += -ustrip(f[dim])
             end
+            j = shfl_sync(FULLMASK, j, lane + 1)
         end
-        d *= Int32(2)
     end
 
-    # Accumulated force
-    if tix == 1
-        @inbounds for dim in 1:D
-            forces[dim, i] = forces_shmem[dim, 1]
-        end
+    sync_warp()
+    @inbounds for dim in 1:D
+        Atomix.@atomic :monotonic forces[dim, i_0_block + tidx - 1] += forces_shmem[dim, tidx]
     end
 
     return nothing
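
For context, not part of the commit: each 1D thread block of the tiled kernel now covers an n_threads_gpu * n_threads_gpu patch of the atom-pair matrix, which is why the no-neighbour-list launch above switches from blocks=n_atoms to an (n_blocks, n_blocks) grid. A rough host-side sketch of that coverage with hypothetical sizes (the min() clamping here stands in for the bounds check the kernel still marks as a TODO):

# Illustrative only: how an (n_blocks, n_blocks) grid of 1D thread blocks
# tiles the full n_atoms x n_atoms pair matrix (hypothetical sizes).
n_atoms = 1000
n_threads_gpu = 256
n_blocks = cld(n_atoms, n_threads_gpu)

covered = falses(n_atoms, n_atoms)
for bx in 1:n_blocks, by in 1:n_blocks
    i_0_block = (bx - 1) * n_threads_gpu + 1   # same block-to-tile mapping as the kernel
    j_0_block = (by - 1) * n_threads_gpu + 1
    is = i_0_block:min(i_0_block + n_threads_gpu - 1, n_atoms)
    js = j_0_block:min(j_0_block + n_threads_gpu - 1, n_atoms)
    covered[is, js] .= true
end
@assert all(covered)   # every (i, j) pair falls in exactly one block's tile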

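A rough CPU model of the per-warp tile walk in the kernel above (my reading, not code from this commit): each of the 32 lanes starts on a different column j of a warp-sized tile, and the j values are rotated with shfl_sync every step so the warp touches 32 distinct columns at once; on a diagonal tile the starting j is offset by one and a step is dropped so i == j is never computed. The shuffle is emulated here with a cyclic shift of a plain vector:

# CPU emulation of one warp's pass over a 32 x 32 tile (illustrative only).
function tile_pairs(i_0_tile, j_0_tile; warpsize=32)
    is = [i_0_tile + lane - 1 for lane in 1:warpsize]
    js = [j_0_tile + lane - 1 for lane in 1:warpsize]
    tilesteps = warpsize
    if i_0_tile == j_0_tile   # diagonal tile: offset j by one, skip one step
        # wrap formula mirrors the kernel; exercised on the first diagonal tile below
        js = [j % (j_0_tile + warpsize - 1) + 1 for j in js]
        tilesteps -= 1
    end
    pairs = Tuple{Int, Int}[]
    for _ in 1:tilesteps
        append!(pairs, collect(zip(is, js)))
        js = circshift(js, -1)   # stands in for shfl_sync(FULLMASK, j, lane + 1)
    end
    return pairs
end

offdiag = tile_pairs(1, 33)
@assert allunique(offdiag) && length(offdiag) == 32^2    # all 1024 pairs, each once
ondiag = tile_pairs(1, 1)
@assert all(i != j for (i, j) in ondiag)                 # self-interactions are skipped
@assert allunique(ondiag) && length(ondiag) == 32 * 31   # every i != j pair, each once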