diff --git a/src/HerbConstraints.jl b/src/HerbConstraints.jl index 555bd49..dc39fba 100644 --- a/src/HerbConstraints.jl +++ b/src/HerbConstraints.jl @@ -47,11 +47,12 @@ include("solver/generic_solver/state.jl") include("solver/generic_solver/generic_solver.jl") include("solver/generic_solver/treemanipulations.jl") -include("solver/fixed_shaped_solver/state_manager.jl") -include("solver/fixed_shaped_solver/state_sparse_set.jl") -include("solver/fixed_shaped_solver/state_fixed_shaped_hole.jl") -include("solver/fixed_shaped_solver/fixed_shaped_solver.jl") -include("solver/fixed_shaped_solver/fixed_shaped_solver_treemanipulations.jl") +include("solver/uniform_solver/state_manager.jl") +include("solver/uniform_solver/state_stack.jl") +include("solver/uniform_solver/state_sparse_set.jl") +include("solver/uniform_solver/state_hole.jl") +include("solver/uniform_solver/uniform_solver.jl") +include("solver/uniform_solver/uniform_treemanipulations.jl") include("solver/domainutils.jl") include("patternmatch.jl") @@ -126,8 +127,7 @@ export increment!, decrement!, - #fixed shaped solver - next_solution!, + #uniform solver UniformSolver, #state fixed shaped hole diff --git a/src/solver/generic_solver/treemanipulations.jl b/src/solver/generic_solver/treemanipulations.jl index 4b76d50..b8dd511 100644 --- a/src/solver/generic_solver/treemanipulations.jl +++ b/src/solver/generic_solver/treemanipulations.jl @@ -250,7 +250,7 @@ function simplify_hole!(solver::GenericSolver, path::Vector{Int}) for i ∈ 1:length(new_node.children) # try to simplify the new children child_path = push!(copy(path), i) - if (new_node.children[i] isa Hole) + if (new_node.children[i] isa AbstractHole) simplify_hole!(solver, child_path) end end diff --git a/src/solver/fixed_shaped_solver/state_fixed_shaped_hole.jl b/src/solver/uniform_solver/state_hole.jl similarity index 100% rename from src/solver/fixed_shaped_solver/state_fixed_shaped_hole.jl rename to src/solver/uniform_solver/state_hole.jl diff --git a/src/solver/fixed_shaped_solver/state_manager.jl b/src/solver/uniform_solver/state_manager.jl similarity index 100% rename from src/solver/fixed_shaped_solver/state_manager.jl rename to src/solver/uniform_solver/state_manager.jl diff --git a/src/solver/fixed_shaped_solver/state_sparse_set.jl b/src/solver/uniform_solver/state_sparse_set.jl similarity index 100% rename from src/solver/fixed_shaped_solver/state_sparse_set.jl rename to src/solver/uniform_solver/state_sparse_set.jl diff --git a/src/solver/uniform_solver/state_stack.jl b/src/solver/uniform_solver/state_stack.jl new file mode 100644 index 0000000..df6be52 --- /dev/null +++ b/src/solver/uniform_solver/state_stack.jl @@ -0,0 +1,69 @@ +""" +Simple stack that can only increase in size. +Supports backtracking by decreasing the size to the saved size. +""" +struct StateStack{T} + vec::Vector{T} + size::StateInt +end + +""" + function StateStack{T}(sm::AbstractStateManager) where T + +Create an empty StateStack supporting elements of type T +""" +function StateStack{T}(sm::AbstractStateManager) where T + return StateStack{T}(Vector{T}(), StateInt(sm, 0)) +end + +""" + function StateStack{T}(sm::AbstractStateManager, vec::Vector{T}) where T + +Create a StateStack for the provided `vec` +""" +function StateStack{T}(sm::AbstractStateManager, vec::Vector{T}) where T + return StateStack{T}(vec, StateInt(sm, length(vec))) +end + +""" + function Base.push!(stack::StateStack, item) + +Place an `item` on top of the `stack`. +""" +function Base.push!(stack::StateStack, item) + increment!(stack.size) + if length(stack.vec) > size(stack) + stack.vec[size(stack)] = item + else + push!(stack.vec, item) + end +end + +""" + function Base.size(stack::StateStack) + +Get the current size of the `stack`. +""" +function Base.size(stack::StateStack)::Int + return get_value(stack.size) +end + +""" + function Base.collect(stack::StateStack) + +Return the internal `Vector` representation of the `stack`. +!!! warning: + The returned vector is read-only. +""" +function Base.collect(stack::StateStack) + return stack.vec[1:size(stack)] +end + +""" + function Base.in(stack::StateStack, value)::Bool + +Checks whether the `value` is in the `stack`. +""" +function Base.in(stack::StateStack, value)::Bool + return value ∈ stack.vec[1:size(stack)] +end diff --git a/src/solver/fixed_shaped_solver/fixed_shaped_solver.jl b/src/solver/uniform_solver/uniform_solver.jl similarity index 56% rename from src/solver/fixed_shaped_solver/fixed_shaped_solver.jl rename to src/solver/uniform_solver/uniform_solver.jl index 5dfc7d1..7c83435 100644 --- a/src/solver/fixed_shaped_solver/fixed_shaped_solver.jl +++ b/src/solver/uniform_solver/uniform_solver.jl @@ -1,48 +1,32 @@ """ -Representation of a branching constraint in the search tree. -""" -struct Branch - hole::StateHole - rule::Int -end - -#Shared reference to an empty vector to reduce memory allocations. -NOBRANCHES = Vector{Branch}() - -""" -A DFS solver that uses `StateHole`s. +A DFS-based solver that uses `StateHole`s that support backtracking. """ mutable struct UniformSolver <: Solver grammar::AbstractGrammar sm::StateManager tree::Union{RuleNode, StateHole} - unvisited_branches::Stack{Vector{Branch}} path_to_node::Dict{Vector{Int}, AbstractRuleNode} node_to_path::Dict{AbstractRuleNode, Vector{Int}} isactive::Dict{AbstractLocalConstraint, StateInt} canceledconstraints::Set{AbstractLocalConstraint} - nsolutions::Int isfeasible::Bool schedule::PriorityQueue{AbstractLocalConstraint, Int} fix_point_running::Bool statistics::Union{SolverStatistics, Nothing} - derivation_heuristic end """ UniformSolver(grammar::AbstractGrammar, fixed_shaped_tree::AbstractRuleNode) """ -function UniformSolver(grammar::AbstractGrammar, fixed_shaped_tree::AbstractRuleNode; with_statistics=false, derivation_heuristic=nothing) +function UniformSolver(grammar::AbstractGrammar, fixed_shaped_tree::AbstractRuleNode; with_statistics=false) @assert !contains_nonuniform_hole(fixed_shaped_tree) "$(fixed_shaped_tree) contains non-uniform holes" sm = StateManager() tree = StateHole(sm, fixed_shaped_tree) - unvisited_branches = Stack{Vector{Branch}}() path_to_node = Dict{Vector{Int}, AbstractRuleNode}() node_to_path = Dict{AbstractRuleNode, Vector{Int}}() isactive = Dict{AbstractLocalConstraint, StateInt}() canceledconstraints = Set{AbstractLocalConstraint}() - nsolutions = 0 schedule = PriorityQueue{AbstractLocalConstraint, Int}() fix_point_running = false statistics = @match with_statistics begin @@ -51,13 +35,9 @@ function UniformSolver(grammar::AbstractGrammar, fixed_shaped_tree::AbstractRule ::Nothing => nothing end if !isnothing(statistics) statistics.name = "UniformSolver" end - solver = UniformSolver(grammar, sm, tree, unvisited_branches, path_to_node, node_to_path, isactive, canceledconstraints, nsolutions, true, schedule, fix_point_running, statistics, derivation_heuristic) + solver = UniformSolver(grammar, sm, tree, path_to_node, node_to_path, isactive, canceledconstraints, true, schedule, fix_point_running, statistics) notify_new_nodes(solver, tree, Vector{Int}()) fix_point!(solver) - if isfeasible(solver) - save_state!(solver) - push!(unvisited_branches, generate_branches(solver)) #generate initial branches for the root search node - end return solver end @@ -144,6 +124,7 @@ function deactivate!(solver::UniformSolver, constraint::AbstractLocalConstraint) end if constraint ∈ keys(solver.isactive) # the constraint was posted earlier and should be deactivated + track!(solver.statistics, "deactivate!") set_value!(solver.isactive[constraint], 0) return end @@ -173,8 +154,12 @@ function post!(solver::UniformSolver, constraint::AbstractLocalConstraint) return end #if the was not deactivated after initial propagation, it can be added to the list of constraints - @assert constraint ∉ keys(solver.isactive) "Attempted to post a constraint that was already posted before" - solver.isactive[constraint] = StateInt(solver.sm, 1) + if (constraint ∈ keys(solver.isactive)) + @assert solver.isactive[constraint] == 0 "Attempted to post a constraint that is already active" + else + solver.isactive[constraint] = StateInt(solver.sm, 0) #initializing the state int as 0 will deactivate it on backtrack + end + set_value!(solver.isactive[constraint], 1) end @@ -234,114 +219,3 @@ function restore!(solver::UniformSolver) solver.isfeasible = true end -#TODO: implement more branching schemes -""" -Returns a vector of disjoint branches to expand the search tree at its current state. -Example: -``` -# pseudo code -Hole(domain=[2, 4, 5], children=[ - Hole(domain=[1, 6]), - Hole(domain=[1, 6]) -]) -``` -A possible branching scheme could be to be split up in three `Branch`ing constraints: -- `Branch(firsthole, 2)` -- `Branch(firsthole, 4)` -- `Branch(firsthole, 5)` -""" -function generate_branches(solver::UniformSolver)::Vector{Branch} - #omitting `::Vector{Branch}` from `_dfs` speeds up the search by a factor of 2 - @assert isfeasible(solver) - function _dfs(node::Union{StateHole, RuleNode}) #::Vector{Branch} - if node isa StateHole && size(node.domain) > 1 - #use the derivation_heuristic if the parent_iterator is set up - if isnothing(solver.derivation_heuristic) - return [Branch(node, rule) for rule ∈ node.domain] - end - return [Branch(node, rule) for rule ∈ solver.derivation_heuristic(findall(node.domain))] - end - for child ∈ node.children - branches = _dfs(child) - if !isempty(branches) - return branches - end - end - return NOBRANCHES - end - return _dfs(solver.tree) -end - - -""" - next_solution!(solver::UniformSolver)::Union{RuleNode, StateHole, Nothing} - -Built-in iterator. Search for the next unvisited solution. -Returns nothing if all solutions have been found already. -""" -function next_solution!(solver::UniformSolver)::Union{RuleNode, StateHole, Nothing} - if solver.nsolutions == 1000000 @warn "UniformSolver is iterating over more than 1000000 solutions..." end - if solver.nsolutions > 0 - # backtrack from the previous solution - restore!(solver) - end - while length(solver.unvisited_branches) > 0 - branches = first(solver.unvisited_branches) - if length(branches) > 0 - # current depth has unvisted branches, pick a branch to explore - branch = pop!(branches) - save_state!(solver) - remove_all_but!(solver, solver.node_to_path[branch.hole], branch.rule) - if isfeasible(solver) - # generate new branches for the new search node - branches = generate_branches(solver) - if length(branches) == 0 - # search node is a solution leaf node, return the solution - solver.nsolutions += 1 - track!(solver.statistics, "#CompleteTrees") - return solver.tree - else - # search node is an (non-root) internal node, store the branches to visit - track!(solver.statistics, "#InternalSearchNodes") - push!(solver.unvisited_branches, branches) - end - else - # search node is an infeasible leaf node, backtrack - track!(solver.statistics, "#InfeasibleTrees") - restore!(solver) - end - else - # search node is an exhausted internal node, backtrack - restore!(solver) - pop!(solver.unvisited_branches) - end - end - if solver.nsolutions == 0 && isfeasible(solver) - _isfilledrecursive(node) = isfilled(node) && all(_isfilledrecursive(c) for c ∈ node.children) - if _isfilledrecursive(solver.tree) - # search node is the root and the only solution, return the solution (edgecase) - solver.nsolutions += 1 - track!(solver.statistics, "#CompleteTrees") - return solver.tree - end - end - return nothing -end - - -""" - count_solutions(solver::UniformSolver) - -Iterate over all solutions and count the number of solutions encountered. -!!! warning: - Solutions are overwritten. It is not possible to return all the solutions without copying. -""" -function count_solutions(solver::UniformSolver) - count = 0 - s = next_solution!(solver) - while !isnothing(s) - count += 1 - s = next_solution!(solver) - end - return count -end diff --git a/src/solver/fixed_shaped_solver/fixed_shaped_solver_treemanipulations.jl b/src/solver/uniform_solver/uniform_treemanipulations.jl similarity index 100% rename from src/solver/fixed_shaped_solver/fixed_shaped_solver_treemanipulations.jl rename to src/solver/uniform_solver/uniform_treemanipulations.jl diff --git a/test/runtests.jl b/test/runtests.jl index 73e8539..7d6cde2 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -14,9 +14,9 @@ using Test include("test_ordered.jl") include("test_contains.jl") + include("test_state_stack.jl") include("test_state_sparse_set.jl") include("test_state_manager.jl") - include("test_fixed_shaped_solver.jl") include("test_state_fixed_shaped_hole.jl") end diff --git a/test/test_fixed_shaped_solver.jl b/test/test_fixed_shaped_solver.jl deleted file mode 100644 index 1cd4234..0000000 --- a/test/test_fixed_shaped_solver.jl +++ /dev/null @@ -1,111 +0,0 @@ -@testset verbose=true "UniformSolver" begin - - function create_dummy_grammar_and_tree_128programs() - grammar = @csgrammar begin - Number = Number + Number - Number = Number - Number - Number = Number * Number - Number = Number / Number - Number = x | 1 | 2 | 3 - end - - fixed_shaped_tree = RuleNode(1, [ - UniformHole(BitVector((1, 1, 1, 1, 0, 0, 0, 0)), [ - UniformHole(BitVector((0, 0, 0, 0, 1, 1, 1, 1)), []) - UniformHole(BitVector((0, 0, 0, 0, 1, 0, 0, 1)), []) - ]), - UniformHole(BitVector((0, 0, 0, 0, 1, 1, 1, 1)), []) - ]) - # 4 * 4 * 2 * 4 = 128 programs without constraints - - return grammar, fixed_shaped_tree - end - - @testset "Without constraints" begin - grammar, fixed_shaped_tree = create_dummy_grammar_and_tree_128programs() - fixed_shaped_solver = UniformSolver(grammar, fixed_shaped_tree) - @test HerbConstraints.count_solutions(fixed_shaped_solver) == 128 - end - - @testset "Forbidden constraint" begin - #forbid "a - a" - grammar, fixed_shaped_tree = create_dummy_grammar_and_tree_128programs() - addconstraint!(grammar, Forbidden(RuleNode(2, [VarNode(:a), VarNode(:a)]))) - fixed_shaped_solver = UniformSolver(grammar, fixed_shaped_tree) - @test HerbConstraints.count_solutions(fixed_shaped_solver) == 120 - - #forbid all rulenodes - grammar, fixed_shaped_tree = create_dummy_grammar_and_tree_128programs() - addconstraint!(grammar, Forbidden(VarNode(:a))) - fixed_shaped_solver = UniformSolver(grammar, fixed_shaped_tree) - @test HerbConstraints.count_solutions(fixed_shaped_solver) == 0 - end - - @testset "The root is the only solution" begin - grammar = @csgrammar begin - S = 1 - end - - solver = UniformSolver(grammar, RuleNode(1)) - @test next_solution!(solver) == RuleNode(1) - @test isnothing(next_solution!(solver)) - end - - @testset "No solutions (ordered constraint)" begin - grammar = @csgrammar begin - Number = 1 - Number = x - Number = Number + Number - Number = Number - Number - end - constraint1 = Ordered(RuleNode(3, [ - VarNode(:a), - VarNode(:b) - ]), [:a, :b]) - constraint2 = Ordered(RuleNode(4, [ - VarNode(:a), - VarNode(:b) - ]), [:a, :b]) - addconstraint!(grammar, constraint1) - addconstraint!(grammar, constraint2) - - tree = UniformHole(BitVector((0, 0, 1, 1)), [ - UniformHole(BitVector((0, 0, 1, 1)), [ - UniformHole(BitVector((1, 1, 0, 0)), []), - UniformHole(BitVector((1, 1, 0, 0)), []) - ]), - UniformHole(BitVector((1, 1, 0, 0)), []) - ]) - solver = UniformSolver(grammar, tree) - @test isnothing(next_solution!(solver)) - end - - @testset "No solutions (forbidden constraint)" begin - grammar = @csgrammar begin - Number = 1 - Number = x - Number = Number + Number - Number = Number - Number - end - constraint1 = Forbidden(RuleNode(3, [ - VarNode(:a), - VarNode(:b) - ])) - constraint2 = Forbidden(RuleNode(4, [ - VarNode(:a), - VarNode(:b) - ])) - addconstraint!(grammar, constraint1) - addconstraint!(grammar, constraint2) - - tree = UniformHole(BitVector((0, 0, 1, 1)), [ - UniformHole(BitVector((0, 0, 1, 1)), [ - UniformHole(BitVector((1, 1, 0, 0)), []), - UniformHole(BitVector((1, 1, 0, 0)), []) - ]), - UniformHole(BitVector((1, 1, 0, 0)), []) - ]) - solver = UniformSolver(grammar, tree) - @test isnothing(next_solution!(solver)) - end -end diff --git a/test/test_state_manager.jl b/test/test_state_manager.jl index f1a5e95..a1bbe7f 100644 --- a/test/test_state_manager.jl +++ b/test/test_state_manager.jl @@ -96,4 +96,13 @@ @test get_value(a) == 10 @test get_value(b) == 100 end + + @testset "behavior needed for constraint unposting" begin + sm = HerbConstraints.StateManager() + save_state!(sm) + a = StateInt(sm, 0) # initialize a new state int, set its value to 0 (post a new constraint) + set_value!(a, 1) # immediately update the state int to 1 (activate the newly posted constraint) + restore!(sm) + @test get_value(a) == 0 # on backtrack, the new constraint should be deactivated + end end diff --git a/test/test_state_stack.jl b/test/test_state_stack.jl new file mode 100644 index 0000000..3fca689 --- /dev/null +++ b/test/test_state_stack.jl @@ -0,0 +1,74 @@ +@testset verbose=true "StateStack" begin + @testset "empty" begin + sm = HerbConstraints.StateManager() + stack = HerbConstraints.StateStack{Int}(sm) + @test size(stack) == 0 + + sm = HerbConstraints.StateManager() + stack = HerbConstraints.StateStack{String}(sm) + @test size(stack) == 0 + end + + @testset "from vector" begin + sm = HerbConstraints.StateManager() + stack = HerbConstraints.StateStack{Int}(sm, [10, 20, 30]) + values = collect(stack) + @test values[1] == 10 + @test values[2] == 20 + @test values[3] == 30 + + sm = HerbConstraints.StateManager() + stack = HerbConstraints.StateStack{String}(sm, ["A", "B", "C"]) + values = collect(stack) + @test values[1] == "A" + @test values[2] == "B" + @test values[3] == "C" + end + + @testset "restore" begin + sm = HerbConstraints.StateManager() + stack = HerbConstraints.StateStack{Int}(sm) + + push!(stack, 10); @test size(stack) == 1 + push!(stack, 20); @test size(stack) == 2 + + save_state!(sm); @test size(stack) == 2 + push!(stack, 30); @test size(stack) == 3 + push!(stack, 40); @test size(stack) == 4 + + restore!(sm); @test size(stack) == 2 + push!(stack, 50); @test size(stack) == 3 #overwrites "30" with "50" + + values = collect(stack) + @test values[1] == 10 + @test values[2] == 20 + @test values[3] == 50 + end + + @testset "restore twice" begin + sm = HerbConstraints.StateManager() + stack = HerbConstraints.StateStack{Int}(sm) + + push!(stack, 10); @test size(stack) == 1 #[10] + push!(stack, 20); @test size(stack) == 2 #[10, 20] + + save_state!(sm); @test size(stack) == 2 #[10, 20] + push!(stack, 30); @test size(stack) == 3 #[10, 20, 30] + push!(stack, 40); @test size(stack) == 4 #[10, 20, 30, 40] + + save_state!(sm); @test size(stack) == 4 #[10, 20, 30, 40] + push!(stack, 50); @test size(stack) == 5 #[10, 20, 30, 40, 50] + push!(stack, 60); @test size(stack) == 6 #[10, 20, 30, 40, 50, 60] + + restore!(sm); @test size(stack) == 4 #[10, 20, 30, 40] + push!(stack, 70); @test size(stack) == 5 #[10, 20, 30, 40, 70] + + restore!(sm); @test size(stack) == 2 #[10, 20] + push!(stack, 80); @test size(stack) == 3 #[10, 20, 80] + + values = collect(stack) + @test values[1] == 10 + @test values[2] == 20 + @test values[3] == 80 + end +end