Skip to content

Commit

Permalink
Merge pull request #34 from Herb-AI/uniform-solver
Browse files Browse the repository at this point in the history
Move iteration out of the UniformSolver (PR 1/2)
  • Loading branch information
ReubenJ authored Apr 30, 2024
2 parents 7d44ec1 + 3fd7978 commit 8a9c001
Show file tree
Hide file tree
Showing 12 changed files with 171 additions and 256 deletions.
14 changes: 7 additions & 7 deletions src/HerbConstraints.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,12 @@ include("solver/generic_solver/state.jl")
include("solver/generic_solver/generic_solver.jl")
include("solver/generic_solver/treemanipulations.jl")

include("solver/fixed_shaped_solver/state_manager.jl")
include("solver/fixed_shaped_solver/state_sparse_set.jl")
include("solver/fixed_shaped_solver/state_fixed_shaped_hole.jl")
include("solver/fixed_shaped_solver/fixed_shaped_solver.jl")
include("solver/fixed_shaped_solver/fixed_shaped_solver_treemanipulations.jl")
include("solver/uniform_solver/state_manager.jl")
include("solver/uniform_solver/state_stack.jl")
include("solver/uniform_solver/state_sparse_set.jl")
include("solver/uniform_solver/state_hole.jl")
include("solver/uniform_solver/uniform_solver.jl")
include("solver/uniform_solver/uniform_treemanipulations.jl")
include("solver/domainutils.jl")

include("patternmatch.jl")
Expand Down Expand Up @@ -127,8 +128,7 @@ export
increment!,
decrement!,

#fixed shaped solver
next_solution!,
#uniform solver
UniformSolver,

#state fixed shaped hole
Expand Down
2 changes: 1 addition & 1 deletion src/solver/generic_solver/treemanipulations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ function simplify_hole!(solver::GenericSolver, path::Vector{Int})
for i 1:length(new_node.children)
# try to simplify the new children
child_path = push!(copy(path), i)
if (new_node.children[i] isa Hole)
if (new_node.children[i] isa AbstractHole)
simplify_hole!(solver, child_path)
end
end
Expand Down
File renamed without changes.
File renamed without changes.
69 changes: 69 additions & 0 deletions src/solver/uniform_solver/state_stack.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
"""
Simple stack that can only increase in size.
Supports backtracking by decreasing the size to the saved size.
"""
struct StateStack{T}
vec::Vector{T}
size::StateInt
end

"""
function StateStack{T}(sm::AbstractStateManager) where T
Create an empty StateStack supporting elements of type T
"""
function StateStack{T}(sm::AbstractStateManager) where T
return StateStack{T}(Vector{T}(), StateInt(sm, 0))
end

"""
function StateStack{T}(sm::AbstractStateManager, vec::Vector{T}) where T
Create a StateStack for the provided `vec`
"""
function StateStack{T}(sm::AbstractStateManager, vec::Vector{T}) where T
return StateStack{T}(vec, StateInt(sm, length(vec)))
end

"""
function Base.push!(stack::StateStack, item)
Place an `item` on top of the `stack`.
"""
function Base.push!(stack::StateStack, item)
increment!(stack.size)
if length(stack.vec) > size(stack)
stack.vec[size(stack)] = item
else
push!(stack.vec, item)
end
end

"""
function Base.size(stack::StateStack)
Get the current size of the `stack`.
"""
function Base.size(stack::StateStack)::Int
return get_value(stack.size)
end

"""
function Base.collect(stack::StateStack)
Return the internal `Vector` representation of the `stack`.
!!! warning:
The returned vector is read-only.
"""
function Base.collect(stack::StateStack)
return stack.vec[1:size(stack)]
end

"""
function Base.in(stack::StateStack, value)::Bool
Checks whether the `value` is in the `stack`.
"""
function Base.in(stack::StateStack, value)::Bool
return value stack.vec[1:size(stack)]
end
Original file line number Diff line number Diff line change
@@ -1,48 +1,32 @@
"""
Representation of a branching constraint in the search tree.
"""
struct Branch
hole::StateHole
rule::Int
end

#Shared reference to an empty vector to reduce memory allocations.
NOBRANCHES = Vector{Branch}()

"""
A DFS solver that uses `StateHole`s.
A DFS-based solver that uses `StateHole`s that support backtracking.
"""
mutable struct UniformSolver <: Solver
grammar::AbstractGrammar
sm::StateManager
tree::Union{RuleNode, StateHole}
unvisited_branches::Stack{Vector{Branch}}
path_to_node::Dict{Vector{Int}, AbstractRuleNode}
node_to_path::Dict{AbstractRuleNode, Vector{Int}}
isactive::Dict{AbstractLocalConstraint, StateInt}
canceledconstraints::Set{AbstractLocalConstraint}
nsolutions::Int
isfeasible::Bool
schedule::PriorityQueue{AbstractLocalConstraint, Int}
fix_point_running::Bool
statistics::Union{SolverStatistics, Nothing}
derivation_heuristic
end


"""
UniformSolver(grammar::AbstractGrammar, fixed_shaped_tree::AbstractRuleNode)
"""
function UniformSolver(grammar::AbstractGrammar, fixed_shaped_tree::AbstractRuleNode; with_statistics=false, derivation_heuristic=nothing)
function UniformSolver(grammar::AbstractGrammar, fixed_shaped_tree::AbstractRuleNode; with_statistics=false)
@assert !contains_nonuniform_hole(fixed_shaped_tree) "$(fixed_shaped_tree) contains non-uniform holes"
sm = StateManager()
tree = StateHole(sm, fixed_shaped_tree)
unvisited_branches = Stack{Vector{Branch}}()
path_to_node = Dict{Vector{Int}, AbstractRuleNode}()
node_to_path = Dict{AbstractRuleNode, Vector{Int}}()
isactive = Dict{AbstractLocalConstraint, StateInt}()
canceledconstraints = Set{AbstractLocalConstraint}()
nsolutions = 0
schedule = PriorityQueue{AbstractLocalConstraint, Int}()
fix_point_running = false
statistics = @match with_statistics begin
Expand All @@ -51,13 +35,9 @@ function UniformSolver(grammar::AbstractGrammar, fixed_shaped_tree::AbstractRule
::Nothing => nothing
end
if !isnothing(statistics) statistics.name = "UniformSolver" end
solver = UniformSolver(grammar, sm, tree, unvisited_branches, path_to_node, node_to_path, isactive, canceledconstraints, nsolutions, true, schedule, fix_point_running, statistics, derivation_heuristic)
solver = UniformSolver(grammar, sm, tree, path_to_node, node_to_path, isactive, canceledconstraints, true, schedule, fix_point_running, statistics)
notify_new_nodes(solver, tree, Vector{Int}())
fix_point!(solver)
if isfeasible(solver)
save_state!(solver)
push!(unvisited_branches, generate_branches(solver)) #generate initial branches for the root search node
end
return solver
end

Expand Down Expand Up @@ -144,6 +124,7 @@ function deactivate!(solver::UniformSolver, constraint::AbstractLocalConstraint)
end
if constraint keys(solver.isactive)
# the constraint was posted earlier and should be deactivated
track!(solver.statistics, "deactivate!")
set_value!(solver.isactive[constraint], 0)
return
end
Expand Down Expand Up @@ -173,8 +154,12 @@ function post!(solver::UniformSolver, constraint::AbstractLocalConstraint)
return
end
#if the was not deactivated after initial propagation, it can be added to the list of constraints
@assert constraint keys(solver.isactive) "Attempted to post a constraint that was already posted before"
solver.isactive[constraint] = StateInt(solver.sm, 1)
if (constraint keys(solver.isactive))
@assert solver.isactive[constraint] == 0 "Attempted to post a constraint that is already active"
else
solver.isactive[constraint] = StateInt(solver.sm, 0) #initializing the state int as 0 will deactivate it on backtrack
end
set_value!(solver.isactive[constraint], 1)
end


Expand Down Expand Up @@ -234,114 +219,3 @@ function restore!(solver::UniformSolver)
solver.isfeasible = true
end

#TODO: implement more branching schemes
"""
Returns a vector of disjoint branches to expand the search tree at its current state.
Example:
```
# pseudo code
Hole(domain=[2, 4, 5], children=[
Hole(domain=[1, 6]),
Hole(domain=[1, 6])
])
```
A possible branching scheme could be to be split up in three `Branch`ing constraints:
- `Branch(firsthole, 2)`
- `Branch(firsthole, 4)`
- `Branch(firsthole, 5)`
"""
function generate_branches(solver::UniformSolver)::Vector{Branch}
#omitting `::Vector{Branch}` from `_dfs` speeds up the search by a factor of 2
@assert isfeasible(solver)
function _dfs(node::Union{StateHole, RuleNode}) #::Vector{Branch}
if node isa StateHole && size(node.domain) > 1
#use the derivation_heuristic if the parent_iterator is set up
if isnothing(solver.derivation_heuristic)
return [Branch(node, rule) for rule node.domain]
end
return [Branch(node, rule) for rule solver.derivation_heuristic(findall(node.domain))]
end
for child node.children
branches = _dfs(child)
if !isempty(branches)
return branches
end
end
return NOBRANCHES
end
return _dfs(solver.tree)
end


"""
next_solution!(solver::UniformSolver)::Union{RuleNode, StateHole, Nothing}
Built-in iterator. Search for the next unvisited solution.
Returns nothing if all solutions have been found already.
"""
function next_solution!(solver::UniformSolver)::Union{RuleNode, StateHole, Nothing}
if solver.nsolutions == 1000000 @warn "UniformSolver is iterating over more than 1000000 solutions..." end
if solver.nsolutions > 0
# backtrack from the previous solution
restore!(solver)
end
while length(solver.unvisited_branches) > 0
branches = first(solver.unvisited_branches)
if length(branches) > 0
# current depth has unvisted branches, pick a branch to explore
branch = pop!(branches)
save_state!(solver)
remove_all_but!(solver, solver.node_to_path[branch.hole], branch.rule)
if isfeasible(solver)
# generate new branches for the new search node
branches = generate_branches(solver)
if length(branches) == 0
# search node is a solution leaf node, return the solution
solver.nsolutions += 1
track!(solver.statistics, "#CompleteTrees")
return solver.tree
else
# search node is an (non-root) internal node, store the branches to visit
track!(solver.statistics, "#InternalSearchNodes")
push!(solver.unvisited_branches, branches)
end
else
# search node is an infeasible leaf node, backtrack
track!(solver.statistics, "#InfeasibleTrees")
restore!(solver)
end
else
# search node is an exhausted internal node, backtrack
restore!(solver)
pop!(solver.unvisited_branches)
end
end
if solver.nsolutions == 0 && isfeasible(solver)
_isfilledrecursive(node) = isfilled(node) && all(_isfilledrecursive(c) for c node.children)
if _isfilledrecursive(solver.tree)
# search node is the root and the only solution, return the solution (edgecase)
solver.nsolutions += 1
track!(solver.statistics, "#CompleteTrees")
return solver.tree
end
end
return nothing
end


"""
count_solutions(solver::UniformSolver)
Iterate over all solutions and count the number of solutions encountered.
!!! warning:
Solutions are overwritten. It is not possible to return all the solutions without copying.
"""
function count_solutions(solver::UniformSolver)
count = 0
s = next_solution!(solver)
while !isnothing(s)
count += 1
s = next_solution!(solver)
end
return count
end
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@ using Test
include("test_ordered.jl")
include("test_contains.jl")

include("test_state_stack.jl")
include("test_state_sparse_set.jl")
include("test_state_manager.jl")

include("test_fixed_shaped_solver.jl")
include("test_state_fixed_shaped_hole.jl")
end
Loading

0 comments on commit 8a9c001

Please sign in to comment.