Skip to content

Commit

Permalink
unroll GBn2 Born radii loop
Browse files Browse the repository at this point in the history
  • Loading branch information
jgreener64 committed Oct 30, 2024
1 parent ba2dd61 commit 9db6db3
Showing 1 changed file with 18 additions and 18 deletions.
36 changes: 18 additions & 18 deletions src/interactions/implicit_solvent.jl
Original file line number Diff line number Diff line change
Expand Up @@ -709,21 +709,13 @@ function born_radii_and_grad(inter::ImplicitSolventOBC, coords::CuArray, boundar
return Bs, B_grads, I_grads
end

struct BornRadiiGBN2LoopResult{I, G}
I::I
I_grad::G
end

get_I(r::BornRadiiGBN2LoopResult) = r.I
get_I_grad(r::BornRadiiGBN2LoopResult) = r.I_grad

function born_radii_loop_GBN2(coord_i::SVector{D, C}, coord_j, ori, orj, srj, dist_cutoff,
offset, neck_scale, neck_cut, d0, m0, boundary) where {D, C}
I = zero(coord_i[1] / unit(dist_cutoff)^2)
I_grad = zero(coord_i[1] / unit(dist_cutoff)^3)
r = norm(vector(coord_i, coord_j, boundary))
if iszero_value(r) || (!iszero_value(dist_cutoff) && r > dist_cutoff)
return BornRadiiGBN2LoopResult(I, I_grad)
return I, I_grad
end
U = r + srj
if ori < U
Expand All @@ -747,17 +739,25 @@ function born_radii_loop_GBN2(coord_i::SVector{D, C}, coord_j, ori, orj, srj, di
numer = 2 * r_d0_strip + 9 * r_d0_strip^5 / 5
I_grad -= 10 * neck_scale * m0 * numer / (denom^2 * unit(dist_cutoff))
end
return BornRadiiGBN2LoopResult(I, I_grad)
return I, I_grad
end

function born_radii_and_grad(inter::ImplicitSolventGBN2, coords, boundary)
coords_i = @view coords[inter.is]
coords_j = @view coords[inter.js]
loop_res = born_radii_loop_GBN2.(coords_i, coords_j, inter.oris, inter.orjs, inter.srjs,
inter.dist_cutoff, inter.offset, inter.neck_scale,
inter.neck_cut, inter.d0s, inter.m0s, (boundary,))
Is = dropdims(sum(get_I.(loop_res); dims=2); dims=2)
I_grads = get_I_grad.(loop_res)
function born_radii_and_grad(inter::ImplicitSolventGBN2{T}, coords, boundary) where T
Is = fill(zero(T) / unit(inter.dist_cutoff), length(coords))
I_grads = zeros(eltype(Is), length(Is), length(Is)) ./ unit(inter.dist_cutoff)
@inbounds for i in eachindex(coords)
I_sum = zero(eltype(Is))
for j in eachindex(coords)
I, I_grad = born_radii_loop_GBN2(
coords[i], coords[j], inter.oris[i], inter.orjs[j], inter.srjs[j],
inter.dist_cutoff, inter.offset, inter.neck_scale, inter.neck_cut,
inter.d0s[i, j], inter.m0s[i, j], boundary,
)
I_sum += I
I_grads[i, j] = I_grad
end
Is[i] = I_sum
end

Bs_B_grads = born_radii_sum.(inter.offset_radii, inter.offset, Is,
inter.αs, inter.βs, inter.γs)
Expand Down

0 comments on commit 9db6db3

Please sign in to comment.