Skip to content

Commit

Permalink
Introduce StateVector (#27)
Browse files Browse the repository at this point in the history
Wraps the state variables returned to the user in a StateVector, but this is also indexed by the cellnumber, hence according to the docs. The key advantage is that this type allows modifying the underlying data by changing a pointer, allowing for easy update of state variables, and this makes it easier to update the state variables if we want to reuse the buffer for multiple simulations, but with different state variables (typical for rve simulations).
  • Loading branch information
KnutAM authored Jul 4, 2024
1 parent 2a565f1 commit 1426b99
Show file tree
Hide file tree
Showing 5 changed files with 133 additions and 125 deletions.
30 changes: 15 additions & 15 deletions src/DomainBuffers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -107,25 +107,23 @@ getset(b::DomainBuffers, domain) = getset(b[domain])
struct DomainBuffer{I,B,S,SDH<:SubDofHandler} <: AbstractDomainBuffer
set::Vector{I}
itembuffer::B
states::Dict{Int,S} # Always indexed by cell. If desired to have 1 state per e.g. facet, need to have e.g. S::Vector{FS}
old_states::Dict{Int,S} # For interfaces, it is possible/likely that state_here and state_there can be given.
states::StateVariables{S}
sdh::SDH
end

struct ThreadedDomainBuffer{I,B,S,SDH<:SubDofHandler} <: AbstractDomainBuffer
chunks::Vector{Vector{Vector{I}}} # I=Int (cell), I=FacetIndex (facet), or
set::Vector{I} # I=NTuple{2,FacetIndex} (interface)
itembuffer::TaskLocals{B,B} # cell, facet, or interface buffer
states::Dict{Int,S} # To be updated during "work"
old_states::Dict{Int,S} # Only for reference
states::StateVariables{S}
sdh::SDH
end
function ThreadedDomainBuffer(set, itembuffer::AbstractItemBuffer, states, old_states, sdh::SubDofHandler, colors_or_chunks=nothing)
function ThreadedDomainBuffer(set, itembuffer::AbstractItemBuffer, states::StateVariables, sdh::SubDofHandler, colors_or_chunks=nothing)
grid = _getgrid(sdh)
set_vector = collect(set)
chunks = create_chunks(grid, set_vector, colors_or_chunks)
itembuffers = TaskLocals(itembuffer)
return ThreadedDomainBuffer(chunks, set_vector, itembuffers, states, old_states, sdh)
return ThreadedDomainBuffer(chunks, set_vector, itembuffers, states, sdh)
end

get_chunks(db::ThreadedDomainBuffer) = db.chunks
Expand All @@ -138,17 +136,15 @@ getset(b::StdDomainBuffer) = b.set

get_dofhandler(b::StdDomainBuffer) = b.sdh.dh
get_itembuffer(b::StdDomainBuffer) = b.itembuffer
get_state(b::StdDomainBuffer, cellnum::Int) = fast_getindex(b.states, cellnum)
get_old_state(b::StdDomainBuffer, cellnum::Int) = fast_getindex(b.old_states, cellnum)
get_state(b::StdDomainBuffer, cellnum::Int) = fast_getindex(b.states.new, cellnum)
get_old_state(b::StdDomainBuffer, cellnum::Int) = fast_getindex(b.states.old, cellnum)

get_state(b::StdDomainBuffer) = b.states
get_old_state(b::StdDomainBuffer) = b.old_states
get_state(b::StdDomainBuffer) = b.states.new
get_old_state(b::StdDomainBuffer) = b.states.old
get_material(b::StdDomainBuffer) = get_material(get_base(get_itembuffer(b)))

# Update old_states = states after convergence
function update_states!(b::AbstractDomainBuffer)
update_states!(get_old_state(b), get_state(b))
end
# Update old_states = new_states after convergence
update_states!(b::StdDomainBuffer) = update_states!(b.states)

function set_time_increment!(b::StdDomainBuffer, Δt)
set_time_increment!(get_base(get_itembuffer(b)), Δt)
Expand All @@ -163,4 +159,8 @@ function replace_material(db::ThreadedDomainBuffer, replacement_function)
task_ibuf = map(ibuf->_replace_material(ibuf, replacement_function), get_locals(db.itembuffer))
itembuffer = TaskLocals(base_ibuf, task_ibuf)
return _replace_field(db, Val(:itembuffer), itembuffer)
end
end

# Experimental: Insert new states, allows reusing the buffer for multiple simulations with same
# initial state (grid, dh, etc.), but which experience different loading. Typically for RVE simulations.
replace_states!(db::StdDomainBuffer, states::StateVariables) = replace_states!(db.states, states)
14 changes: 7 additions & 7 deletions src/setup.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,21 +85,21 @@ function setup_itembuffer(adb, domain::DomainSpec{Int}, states)
end

function _setup_domainbuffer(threaded, domain; a=nothing, autodiffbuffer=Val(false))
states = create_states(domain, a)
new_states = create_states(domain, a)
old_states = create_states(domain, a)
itembuffer = setup_itembuffer(autodiffbuffer, domain, states)
return _setup_domainbuffer(threaded, domain.set, itembuffer, states, old_states, domain.sdh, domain.colors_or_chunks)
itembuffer = setup_itembuffer(autodiffbuffer, domain, new_states)
return _setup_domainbuffer(threaded, domain.set, itembuffer, StateVariables(old_states, new_states), domain.sdh, domain.colors_or_chunks)
end

# Type-unstable switch
function _setup_domainbuffer(threaded::Bool, args...)
return _setup_domainbuffer(Val(threaded), args...)
end
# Sequential
function _setup_domainbuffer(::Val{false}, set, itembuffer, states, old_states, sdh, args...)
return DomainBuffer(set, itembuffer, states, old_states, sdh)
function _setup_domainbuffer(::Val{false}, set, itembuffer, states, sdh, args...)
return DomainBuffer(set, itembuffer, states, sdh)
end
# Threaded
function _setup_domainbuffer(::Val{true}, set, itembuffer, states, old_states, sdh, args...)
return ThreadedDomainBuffer(set, itembuffer, states, old_states, sdh, args...)
function _setup_domainbuffer(::Val{true}, set, itembuffer, states, sdh, args...)
return ThreadedDomainBuffer(set, itembuffer, states, sdh, args...)
end
49 changes: 27 additions & 22 deletions src/states.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,30 @@
# Minimal interface for a vector, storage format will probably be updated later.
mutable struct StateVector{SV}
vals::Dict{Int, SV}
end
Base.getindex(s::StateVector, cellnum::Int) = s.vals[cellnum]
Base.setindex!(s::StateVector, v, cellnum::Int) = setindex!(s.vals, v, cellnum)
Base.:(==)(a::StateVector, b::StateVector) = (a.vals == b.vals)

struct StateVariables{SV}
old::StateVector{SV} # Rule: Referenced during assembly, not changed (ever)
new::StateVector{SV} # Rule: Updated during assembly, not referenced (before updated)
end
StateVariables(old::Dict, new::Dict) = StateVariables(StateVector(old), StateVector(new))

function update_states!(sv::StateVariables)
tmp = sv.old.vals
sv.old.vals = sv.new.vals
sv.new.vals = tmp
end

# Experimental, basically copy!, but use separate name for clarity
function replace_states!(dst::StateVariables, src::StateVariables)
dst.old.vals = src.old.vals
dst.new.vals = dst.new.vals
return dst
end

"""
create_cell_state(material, cellvalues, x, ae, dofrange)
Expand Down Expand Up @@ -40,25 +67,3 @@ function create_states(sdh::SubDofHandler, material, cellvalues, a, cellset, dof
dofs = zeros(Int, ndofs_per_cell(sdh))
return Dict(cellnr => _create_cell_state(coords, dofs, material, cellvalues, a, ae, dofrange, sdh, cellnr) for cellnr in cellset)
end

"""
update_states!(old_states, states)
In most cases, this 2-argument function is not required, and the
entire domain buffer can be passed instead, see
[`update_states!`](@ref update_states!(::FerriteAssembly.DomainBuffers)).
This 2-argument function will then be called for the stored state variables.
"""
function update_states!(old_states::T, states::T) where T<:Dict{Int,<:Vector}
for (key, new_s) in states
update_states!(old_states[key], new_s)
end
end
function update_states!(::T, ::T) where T<:Union{Vector{Nothing},Dict{Int,Nothing},Dict{Int,Vector{Nothing}}}
return nothing
end
@inline function update_states!(old_states::T, states::T) where T<:Union{Vector{ET},Dict{Int,ET}} where ET
copy_states!(Val(isbitstype(ET)), old_states, states)
end
@inline copy_states!(::Val{true}, old_states, states) = copy!(old_states,states)
@inline copy_states!(::Val{false}, old_states, states) = copy!(old_states,deepcopy(states))
10 changes: 5 additions & 5 deletions test/heatequation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,8 @@
ad2 = DomainSpec(dh, material["B"], cv; set=setB)
buffer = setup_domainbuffers(Dict("A"=>ad1, "B"=>ad2); autodiffbuffer=autodiff_cb, threading=threaded)
@test isa(buffer, Dict{String,<:BufferType})
@test isa(FerriteAssembly.get_old_state(buffer, "A"), Dict{Int})
@test isa(FerriteAssembly.get_old_state(buffer, "B"), Dict{Int})
@test isa(FerriteAssembly.get_old_state(buffer, "A"), FerriteAssembly.StateVector)
@test isa(FerriteAssembly.get_old_state(buffer, "B"), FerriteAssembly.StateVector)
return buffer
elseif isa(material, Dict) && length(dh.subdofhandlers) > 1
sdh1 = dh.subdofhandlers[1]
Expand All @@ -151,7 +151,7 @@
ad4 = DomainSpec(sdh2, material["B"], cv; set=setB) # sdh2B
buffer = setup_domainbuffers(Dict("sdh1A"=>ad1, "sdh1B"=>ad2, "sdh2A"=>ad3, "sdh2B"=>ad4); autodiffbuffer=autodiff_cb, threading=threaded)
@test isa(buffer, Dict{String,<:BufferType})
@test isa(FerriteAssembly.get_old_state(buffer, "sdh1A"), Dict{Int})
@test isa(FerriteAssembly.get_old_state(buffer, "sdh1A"), FerriteAssembly.StateVector)
return buffer
elseif length(dh.subdofhandlers) > 1
sdh1 = dh.subdofhandlers[1]
Expand All @@ -161,12 +161,12 @@
ad2 = DomainSpec(sdh2, material, cv; set=set2)
buffer = setup_domainbuffers(Dict("sdh1"=>ad1, "sdh2"=>ad2); autodiffbuffer=autodiff_cb, threading=threaded)
@test isa(buffer, Dict{String,<:BufferType})
@test isa(FerriteAssembly.get_old_state(buffer, "sdh1"), Dict{Int})
@test isa(FerriteAssembly.get_old_state(buffer, "sdh1"), FerriteAssembly.StateVector)
return buffer
else
buffer = setup_domainbuffer(DomainSpec(dh, material, cv); autodiffbuffer=autodiff_cb, threading=threaded)
@test isa(buffer, BufferType)
@test isa(FerriteAssembly.get_old_state(buffer), Dict{Int})
@test isa(FerriteAssembly.get_old_state(buffer), FerriteAssembly.StateVector)
return buffer
end
end
Expand Down
155 changes: 79 additions & 76 deletions test/states.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,88 +50,91 @@ end
import .TestStateModule: MatA, MatB, MatC, StateA, StateB, StateC

for (CT, Dim) in ((Line, 1), (QuadraticTriangle, 2), (Hexahedron, 3))
grid = generate_grid(CT, ntuple(_->3, Dim))
ip = geometric_interpolation(CT)
dh = DofHandler(grid); add!(dh, :u, ip^Dim); close!(dh);
@testset "$CT" begin
grid = generate_grid(CT, ntuple(_->3, Dim))
ip = geometric_interpolation(CT)
dh = DofHandler(grid); add!(dh, :u, ip^Dim); close!(dh);

K = allocate_matrix(dh)
r = zeros(ndofs(dh));
kr_assembler = start_assemble(K, r)
r_assembler = FerriteAssembly.ReAssembler(r)
a = copy(r)
RefShape = Ferrite.getrefshape(ip)
cv = CellValues(QuadratureRule{RefShape}(2), ip^Dim, ip);
K = allocate_matrix(dh)
r = zeros(ndofs(dh));
kr_assembler = start_assemble(K, r)
r_assembler = FerriteAssembly.ReAssembler(r)
a = copy(r)
RefShape = Ferrite.getrefshape(ip)
cv = CellValues(QuadratureRule{RefShape}(2), ip^Dim, ip);

# MatA:
# - Check correct values before and after update
# - Check unaliased old and new after update_states!
buffer = setup_domainbuffer(DomainSpec(dh, MatA(), cv))
states = FerriteAssembly.get_state(buffer)
old_states = FerriteAssembly.get_old_state(buffer)
@test isa(old_states, Dict{Int,Vector{StateA}})
@test old_states == states
@test old_states[1] == [StateA(-1, 0) for _ in 1:getnquadpoints(cv)]
work!(r_assembler, buffer)
@test old_states[1] == [StateA(-1, 0) for _ in 1:getnquadpoints(cv)] # Unchanged
for cellnr in 1:getncells(grid)
@test states[cellnr] == [StateA(cellnr, i) for i in 1:getnquadpoints(cv)] # Updated
end
update_states!(buffer)
@test old_states == states # Correctly updated values
states[1][1] = StateA(0,0)
@test old_states[1][1] == StateA(1,1) # But not aliased
allocs = @allocated update_states!(buffer)
@test allocs == 0 # Vector{T} where isbitstype(T) should not allocate (MatA fulfills this)
# MatA:
# - Check correct values before and after update
# - Check unaliased old and new after update_states!
buffer = setup_domainbuffer(DomainSpec(dh, MatA(), cv))
states = FerriteAssembly.get_state(buffer)
old_states = FerriteAssembly.get_old_state(buffer)
@test isa(old_states, FerriteAssembly.StateVector{Vector{StateA}})
@test old_states == states
@test old_states[1] == [StateA(-1, 0) for _ in 1:getnquadpoints(cv)]
work!(r_assembler, buffer)
@test old_states[1] == [StateA(-1, 0) for _ in 1:getnquadpoints(cv)] # Unchanged
for cellnr in 1:getncells(grid)
@test states[cellnr] == [StateA(cellnr, i) for i in 1:getnquadpoints(cv)] # Updated
end
states_dc = deepcopy(states) # Allowed to update states during update_states!
update_states!(buffer)
@test old_states == states_dc # Correctly updated values
states[1][1] = StateA(0,0)
@test old_states[1][1] == StateA(1,1) # But not aliased
allocs = @allocated update_states!(buffer)
@test allocs == 0 # Vector{T} where isbitstype(T) should not allocate (MatA fulfills this)

# MatB (not bitstype)
# - Check correct values before and after update
# - Check unaliased old and new after update_states!
buffer = setup_domainbuffer(DomainSpec(dh, MatB{Dim}(), cv))
states = FerriteAssembly.get_state(buffer)
old_states = FerriteAssembly.get_old_state(buffer)
@test isa(old_states, Dict{Int,StateB{Dim}})
@test old_states == states
@test old_states[1] == StateB(-1, [zero(Vec{Dim}) for i in 1:getnquadpoints(cv)])
work!(kr_assembler, buffer)
@test old_states[1] == StateB(-1, [zero(Vec{Dim}) for i in 1:getnquadpoints(cv)]) # Unchanged
for cellnr in 1:getncells(grid)
# MatB (not bitstype)
# - Check correct values before and after update
# - Check unaliased old and new after update_states!
buffer = setup_domainbuffer(DomainSpec(dh, MatB{Dim}(), cv))
states = FerriteAssembly.get_state(buffer)
old_states = FerriteAssembly.get_old_state(buffer)
@test isa(old_states, FerriteAssembly.StateVector{StateB{Dim}})
@test old_states == states
@test old_states[1] == StateB(-1, [zero(Vec{Dim}) for i in 1:getnquadpoints(cv)])
work!(kr_assembler, buffer)
@test old_states[1] == StateB(-1, [zero(Vec{Dim}) for i in 1:getnquadpoints(cv)]) # Unchanged
for cellnr in 1:getncells(grid)
coords = getcoordinates(grid, cellnr)
x_values = [spatial_coordinate(cv, i, coords) for i in 1:getnquadpoints(cv)]
@test states[cellnr] == StateB(cellnr, x_values) # Updated
end
states_dc = deepcopy(states) # Allowed to update states during update_states!
update_states!(buffer)
@test old_states == states_dc # Correctly updated values
cellnr = rand(1:getncells(grid))
coords = getcoordinates(grid, cellnr)
x_values = [spatial_coordinate(cv, i, coords) for i in 1:getnquadpoints(cv)]
@test states[cellnr] == StateB(cellnr, x_values) # Updated
end
update_states!(buffer)
@test old_states == states # Correctly updated values
cellnr = rand(1:getncells(grid))
coords = getcoordinates(grid, cellnr)
x_values = [spatial_coordinate(cv, i, coords) for i in 1:getnquadpoints(cv)]
states[cellnr] = StateB(0, -x_values)
@test old_states[cellnr] == StateB(cellnr, x_values) # But not aliased
allocs = @allocated update_states!(buffer)
@test allocs > 0 # Vector{T} where !isbitstype(T) is expected to allocate.
# If this fails after improvements, that is just good, but then the docstring should be updated.
states[cellnr] = StateB(0, -x_values)
@test old_states[cellnr] == StateB(cellnr, x_values) # But not aliased
allocs = @allocated update_states!(buffer)
@test allocs == 0 # Vector{T} where !isbitstype(T) should no longer allocate

# MatC (accumulation), using threading as well
colors = create_coloring(grid)
buffer = setup_domainbuffer(DomainSpec(dh, MatC(), cv; colors=colors))
states = FerriteAssembly.get_state(buffer)
old_states = FerriteAssembly.get_old_state(buffer)
@test isa(old_states, Dict{Int,Vector{StateC}})
@test old_states == states
@test old_states[1][1] == StateC(0)
work!(kr_assembler, buffer)
@test old_states[1][1] == StateC(0)
@test states[1][1] == StateC(1) # Added 1
update_states!(buffer)
@test old_states[1][1] == StateC(1) # Updated
states[1][1] = StateC(0) # Set states to zero to test aliasing and update in next assembly
@test old_states[1][1] == StateC(1) # But not aliased
work!(kr_assembler, buffer)
@test states[1][1] == StateC(2) # Added 1 from old_states[1][1] (not from states[1][1] which was StateC(0))
for cellnr in 1:getncells(grid)
@test states[cellnr][2] == StateC(2) # Check that all are updated
# MatC (accumulation), using threading as well
colors = create_coloring(grid)
buffer = setup_domainbuffer(DomainSpec(dh, MatC(), cv; colors=colors))
states = FerriteAssembly.get_state(buffer)
old_states = FerriteAssembly.get_old_state(buffer)
@test isa(old_states, FerriteAssembly.StateVector{Vector{StateC}})
@test old_states == states
@test old_states[1][1] == StateC(0)
work!(kr_assembler, buffer)
@test old_states[1][1] == StateC(0)
@test states[1][1] == StateC(1) # Added 1
update_states!(buffer)
@test old_states[1][1] == StateC(1) # Updated
states[1][1] = StateC(0) # Set states to zero to test aliasing and update in next assembly
@test old_states[1][1] == StateC(1) # But not aliased
work!(kr_assembler, buffer)
@test states[1][1] == StateC(2) # Added 1 from old_states[1][1] (not from states[1][1] which was StateC(0))
for cellnr in 1:getncells(grid)
@test states[cellnr][2] == StateC(2) # Check that all are updated
end
allocs = @allocated update_states!(buffer)
@test allocs == 0 # Vector{T} where isbitstype(T) should not allocate (MatC fulfills this)
end
allocs = @allocated update_states!(buffer)
@test allocs == 0 # Vector{T} where isbitstype(T) should not allocate (MatC fulfills this)
end

ip = Lagrange{RefTriangle,1}()
Expand All @@ -140,7 +143,7 @@ end
# Smoke-test of update_states! for nothing states (and check no allocations)
cv = CellValues(QuadratureRule{RefTriangle}(2), ip)
buffer = setup_domainbuffer(DomainSpec(dh, nothing, cv))
@test isa(FerriteAssembly.get_state(buffer), Dict{Int,Nothing})
@test isa(FerriteAssembly.get_state(buffer), FerriteAssembly.StateVector{Nothing})
update_states!(buffer) # Compile
allocs = @allocated update_states!(buffer)
@test allocs == 0
Expand Down

0 comments on commit 1426b99

Please sign in to comment.