diff --git a/regression_ppx/dune b/regression_ppx/dune index 05771aa4..a881ac6e 100644 --- a/regression_ppx/dune +++ b/regression_ppx/dune @@ -148,13 +148,29 @@ (modules test013mutual) (package OCanren) (public_names -) - (flags - (:standard - ;-dsource - ; - )) (preprocess - (pps OCanren-ppx.ppx_distrib GT.ppx_all -- -new-typenames -pretty)) + (pps + OCanren-ppx.ppx_distrib + GT.ppx_all + OCanren-ppx.ppx_fresh + -- + -new-typenames + -pretty)) + (libraries OCanren OCanren.tester)) + +(executables + (names test014diseq) + (modules test014diseq) + (package OCanren) + (public_names -) + (preprocess + (pps + GT.ppx_all + OCanren-ppx.ppx_fresh + OCanren-ppx.ppx_tester + OCanren-ppx.ppx_repr + -- + -pretty)) (libraries OCanren OCanren.tester)) (cram @@ -283,3 +299,12 @@ %{project_root}/ppx/pp_ocanren_all.exe test013mutual.ml test013mutual.exe)) + +(cram + (package OCanren) + (applies_to test014) + (deps + (package OCanren-ppx) + %{project_root}/ppx/pp_ocanren_all.exe + test014diseq.ml + test014diseq.exe)) diff --git a/regression_ppx/test014.t b/regression_ppx/test014.t new file mode 100644 index 00000000..e71eccae --- /dev/null +++ b/regression_ppx/test014.t @@ -0,0 +1,35 @@ + $ ./test014diseq.exe + rel, 1 answer { + hd1 = _.12 + hd2 = _.14 + tl2 = _.15 + 10: { 0: [| 10 =/= _.11 |] } + + 11: { 0: [| 11 =/= boxed 0 <_.12, _.13> |] } + + 15: { 0: [| 14 =/= _.12, 15 =/= _.13 |] } + + hd2 === 1 + 15: { 0: [| 14 =/= _.12, 15 =/= _.13 |] } + + tl2 === [] + 13: { 0: [| 12 =/= int<1>, 13 =/= int<0> |] } + + 47 + 13: { 0: [| 12 =/= int<1>, 13 =/= int<0> |] } + + } + fun _ -> + fresh x + ((Std.list Fun.id [!< x; !< x]) =/= + (Std.list Fun.id [!< (!! 1); !< (!! 2)])), 1 answer { + q=_.10; + } + fun q -> + fresh (x y) (trace_index "x" x) (trace_index "y" y) ((x % y) === q) + ((x % y) =/= (Std.list Fun.id [!! 1; x])) + (y === (Std.list Fun.id [!! 2])) success, 1 answer { + x = _.11 + y = _.12 + q=[_.11; 2]; + } diff --git a/regression_ppx/test014diseq.ml b/regression_ppx/test014diseq.ml new file mode 100644 index 00000000..3461120f --- /dev/null +++ b/regression_ppx/test014diseq.ml @@ -0,0 +1,87 @@ +open OCanren +open Tester + +let debug_line line = + debug_var !!1 OCanren.reify (function _ -> + Format.printf "%d\n%!" line; + success) +;; + +let trace_index msg var = + debug_var var OCanren.reify (function + | [ Var (n, _) ] -> + Printf.printf "%s = _.%d\n" msg n; + success + | _ -> assert false) +;; + +let trace fmt = + Format.kasprintf + (fun s -> + debug_var !!1 OCanren.reify (function _ -> + Format.printf "%s\n%!" s; + success)) + fmt +;; + +let rel list1 = + let open OCanren.Std in + fresh + (list2 hd1 tl1 hd2 tl2) + (trace_index "hd1" hd1) + (trace_index "hd2" hd2) + (trace_index "tl2" tl2) + (list1 =/= list2) + trace_diseq + (list1 === hd1 % tl1) + trace_diseq + (list2 === hd2 % tl2) + trace_diseq + (trace " hd2 === 1") + (hd2 === !!1) + trace_diseq + (trace " tl2 === []") + (tl2 === nil ()) + trace_diseq + (hd1 === !!1) + (debug_line __LINE__) + trace_diseq + (tl1 === nil ()) (* crashes here *) + (debug_line __LINE__) +;; + +(* let () = [%tester run_r [%show GT.int GT.list] (Std.List.reify reify) 1 (fun q -> rel q)] *) +let () = run_r (Std.List.reify reify) ([%show: GT.int logic Std.List.logic] ()) 1 q qh (REPR rel) + +let () = + let open Std in + run_r + (Std.List.reify reify) + ([%show: GT.int logic Std.List.logic] ()) + 1 + q + qh + (REPR (fun _ -> fresh x (Std.list Fun.id [ ! + fresh + (x y) + (trace_index "x" x) + (trace_index "y" y) + (x % y === q) + (x % y =/= Std.list Fun.id [ !!1; x ]) + (* trace_diseq *) + (y === Std.list Fun.id [ !!2 ]) + (* trace_diseq *) + success)) +;; diff --git a/src/core/Core.ml b/src/core/Core.ml index 954cbde3..b987b54c 100644 --- a/src/core/Core.ml +++ b/src/core/Core.ml @@ -795,4 +795,9 @@ module Tabling = let reify_in_empty reifier x = let st = State.empty () in - reifier (State.env st) x \ No newline at end of file + reifier (State.env st) x + +let trace_diseq : goal = fun st -> + Format.printf "%a\n%!" Disequality.pp (State.constraints st); + success st + diff --git a/src/core/Core.mli b/src/core/Core.mli index fff880bc..795be3cc 100644 --- a/src/core/Core.mli +++ b/src/core/Core.mli @@ -320,4 +320,6 @@ module PrunesControl : sig end (** Runs reifier on empty state. Useful to debug execution order *) -val reify_in_empty: ('a, 'b) Reifier.t -> 'a -> 'b \ No newline at end of file +val reify_in_empty: ('a, 'b) Reifier.t -> 'a -> 'b + +val trace_diseq: goal \ No newline at end of file diff --git a/src/core/Disequality.ml b/src/core/Disequality.ml index 8fab0dfd..f60adce0 100644 --- a/src/core/Disequality.ml +++ b/src/core/Disequality.ml @@ -19,6 +19,11 @@ (* to avoid clash with Std.List (i.e. logic list) *) module List = Stdlib.List +let log fmt = + if false + then Format.kasprintf (Format.printf "%s\n%!") fmt + else Format.ifprintf Format.std_formatter fmt + module Answer = struct module S = Set.Make(Term) @@ -91,6 +96,9 @@ module Disjunct : (* Disjunction.t is a set of single disequalities joint by disjunction *) type t + + val pp : Format.formatter -> t -> unit + (* [make env subst x y] creates new disjunct from the disequality [x =/= y] *) val make : Env.t -> Subst.t -> 'a -> 'a -> t @@ -118,9 +126,19 @@ module Disjunct : struct type t = Term.t Term.VarMap.t - let update t = - ListLabels.fold_left ~init:t - ~f:(let open Subst.Binding in fun acc {var; term} -> + let pp ppf d = + if Term.VarMap.is_empty d then Format.fprintf ppf "" + else + Format.fprintf ppf "[| "; + Term.VarMap.iteri (fun i k v -> + if i<>0 then Format.fprintf ppf ", "; + Format.fprintf ppf " @[%d =/= %s@]" k.Term.Var.index (Term.show v) + ) d; + Format.fprintf ppf " |]" + + let update : t -> _ -> t = fun init -> + ListLabels.fold_left ~init + ~f:(fun acc {Subst.Binding.var; term} -> if Term.VarMap.mem var acc then (* in this case we have subformula of the form (x =/= t1) \/ (x =/= t2) which is always SAT *) raise Disequality_fulfilled @@ -149,13 +167,29 @@ module Disjunct : | Fulfiled -> raise Disequality_fulfilled | Violated -> raise Disequality_violated - let rec recheck env subst t = + let rec recheck env subst (t: t): t = + (* log "Disjunct.recheck: %a" pp t; *) let var, term = Term.VarMap.max_binding t in - let unchecked = Term.VarMap.remove var t in + (* log " max bind index = %d" var.Term.Var.index; *) match refine env subst (Obj.magic var) term with - | Fulfiled -> raise Disequality_fulfilled - | Refined delta -> update unchecked delta + | Fulfiled -> + raise Disequality_fulfilled + | Refined delta -> ( + (* When leading terms are reified into something new, we still need to + do whole unification, because other pairs may need walking --- + (we postponed walking, so some information may be lost.) + See issue #173 + *) + (* log "Refined into: %a" (Format.pp_print_list Subst.Binding.pp) delta; *) + match Subst.unify_map env subst t with + | None -> + (* not unifiable --- always distinct *) + raise Disequality_fulfilled + | Some (delta, _) when Term.VarMap.is_empty delta -> raise Disequality_violated + | Some (bnds, _subst) -> bnds) | Violated -> + let unchecked = Term.VarMap.remove var t in + (* log " unchecked: %a" pp unchecked; *) if Term.VarMap.is_empty unchecked then raise Disequality_violated else @@ -208,6 +242,8 @@ module Conjunct : val empty : t + val pp : Format.formatter -> t -> unit + val is_empty : t -> bool val make : Env.t -> Subst.t -> 'a -> 'a -> t @@ -236,6 +272,19 @@ module Conjunct : type t = Disjunct.t M.t + let pp ppf map = + if M.is_empty map + then Format.fprintf ppf "{}" + else + let idx = ref 0 in + Format.fprintf ppf "{ "; + M.iter (fun k v -> + if !idx <> 0 then Format.fprintf ppf " ,"; + Format.fprintf ppf "@[%d: %a@]" k Disjunct.pp v; + incr idx + ) map; + Format.fprintf ppf " }" + let empty = M.empty let is_empty = M.is_empty @@ -256,11 +305,14 @@ module Conjunct : ) t Term.VarMap.empty let recheck env subst t = - M.fold (fun id disj acc -> + (* log "Conjunct.recheck. %a" pp t; *) + let rez = M.fold (fun id disj acc -> try M.add id (Disjunct.recheck env subst disj) acc with Disequality_fulfilled -> acc - ) t M.empty + ) t M.empty in + (* log "rechecked = %a" pp rez; *) + rez let merge_disjoint env subst = M.union (fun _ _ _ -> @@ -351,6 +403,11 @@ type t = Conjunct.t Term.VarMap.t let empty = Term.VarMap.empty +let pp ppf : t -> unit = + Term.VarMap.iter (fun k v -> + Format.fprintf ppf "@[%d: %a@]@," k.Term.Var.index Conjunct.pp v + ) + (* merges all conjuncts (linked to different variables) into one *) let combine env subst cstore = Term.VarMap.fold (fun _ -> Conjunct.merge_disjoint env subst) cstore Conjunct.empty @@ -370,17 +427,19 @@ let add env subst cstore x y = | Disequality_violated -> None let recheck env subst cstore bs = - let helper var cstore = + let helper var cstore : t = try let conj = Term.VarMap.find var cstore in let cstore = Term.VarMap.remove var cstore in update env subst (Conjunct.recheck env subst conj) cstore + with Not_found -> cstore in try let cstore = ListLabels.fold_left bs ~init:cstore - ~f:(let open Subst.Binding in fun cstore {var; term} -> + ~f:(fun cstore {Subst.Binding.var; term} -> let cstore = helper var cstore in + (* log "cstore = %a" pp cstore; *) match Env.var env term with | Some u -> helper u cstore | None -> cstore diff --git a/src/core/Disequality.mli b/src/core/Disequality.mli index 582d6216..72546a8a 100644 --- a/src/core/Disequality.mli +++ b/src/core/Disequality.mli @@ -51,3 +51,5 @@ module Answer : end val reify : Env.t -> Subst.t -> t -> 'a -> Answer.t list + +val pp: Format.formatter -> t -> unit diff --git a/src/core/Subst.ml b/src/core/Subst.ml index 8aa8b253..9952f402 100644 --- a/src/core/Subst.ml +++ b/src/core/Subst.ml @@ -46,8 +46,18 @@ module Binding = if res <> 0 then res else Term.compare t p let hash {var; term} = Hashtbl.hash (Term.Var.hash var, Term.hash term) + + let pp ppf {var; term} = + Format.fprintf ppf "{ var.idx = %d; term=%s }" var.Term.Var.index (Term.show term) end +let varmap_of_bindings : Binding.t list -> Term.t Term.VarMap.t = + Stdlib.List.fold_left (fun (acc: _ Term.VarMap.t) Binding.{var;term} -> + assert (not (Term.VarMap.mem var acc)); + Term.VarMap.add var term acc + ) + Term.VarMap.empty + type t = Term.t Term.VarMap.t let empty = Term.VarMap.empty @@ -145,37 +155,57 @@ let extend ~scope env subst var term = exception Unification_failed -let unify ?(subsume=false) ?(scope=Term.Var.non_local_scope) env subst x y = +let log fmt = + if false + then Format.kasprintf (Format.printf "%s\n%!") fmt + else Format.ifprintf Format.std_formatter fmt + +let ext ~scope ~env add_delta var term (prefix, subst) = + let subst = extend ~scope env subst var term in + (add_delta Binding.{var; term} prefix, subst) + +let rec unify_helper subsume ext env x y acc = + (* log "unify '%s' and '%s'" (Term.show x) (Term.show y); *) + let open Term in + fold2 x y ~init:acc + ~fvar:(fun ((_, subst) as acc) x y -> + match walk env subst x, walk env subst y with + | Var x, Var y -> + if Var.equal x y then acc else ext x (Term.repr y) acc + | Var x, Value y -> ext x y acc + | Value x, Var y -> ext y x acc + | Value x, Value y -> unify_helper subsume ext env x y acc + ) + ~fval:(fun acc x y -> + if x = y then acc else raise Unification_failed + ) + ~fk:(fun ((_, subst) as acc) l v y -> + if subsume && (l = Term.R) + then raise Unification_failed + else match walk env subst v with + | Var v -> ext v y acc + | Value x -> unify_helper subsume ext env x y acc + ) + +let unify_gen ?(subsume=false) ?(scope=Term.Var.non_local_scope) add_delta empty_delta env subst x y = (* The idea is to do the unification and collect the unification prefix during the process *) - let extend var term (prefix, subst) = - let subst = extend ~scope env subst var term in - (Binding.({var; term})::prefix, subst) - in - let rec helper x y acc = - let open Term in - fold2 x y ~init:acc - ~fvar:(fun ((_, subst) as acc) x y -> - match walk env subst x, walk env subst y with - | Var x, Var y -> - if Var.equal x y then acc else extend x (Term.repr y) acc - | Var x, Value y -> extend x y acc - | Value x, Var y -> extend y x acc - | Value x, Value y -> helper x y acc - ) - ~fval:(fun acc x y -> - if x = y then acc else raise Unification_failed - ) - ~fk:(fun ((_, subst) as acc) l v y -> - if subsume && (l = Term.R) - then raise Unification_failed - else match walk env subst v with - | Var v -> extend v y acc - | Value x -> helper x y acc - ) - in try let x, y = Term.(repr x, repr y) in - Some (helper x y ([], subst)) + Some (unify_helper subsume (ext ~scope ~env add_delta) env x y (empty_delta, subst)) + with Term.Different_shape _ | Unification_failed | Occurs_check -> None + +let unify ?(subsume=false) ?(scope=Term.Var.non_local_scope) = + unify_gen ~subsume ~scope List.cons [] + +let unify_map env subst map : (Obj.t Term.VarMap.t * t) option = + let add_delta {Binding.var; term} m = Term.VarMap.add var term m in + let subsume = false in + let scope = Term.Var.non_local_scope in + try + Stdlib.Option.some @@ + Term.VarMap.fold (fun var term acc -> + unify_helper subsume (ext ~scope ~env add_delta) env (Obj.magic var) term acc + ) map (Term.VarMap.empty, subst) with Term.Different_shape _ | Unification_failed | Occurs_check -> None let apply env subst x = Obj.magic @@ diff --git a/src/core/Subst.mli b/src/core/Subst.mli index b0312209..60b6c6b2 100644 --- a/src/core/Subst.mli +++ b/src/core/Subst.mli @@ -29,8 +29,11 @@ module Binding : val equal : t -> t -> bool val compare : t -> t -> int val hash : t -> int + val pp: Format.formatter -> t -> unit end +val varmap_of_bindings: Binding.t list -> Term.t Term.VarMap.t + type t val empty : t @@ -64,6 +67,8 @@ val freevars : Env.t -> t -> 'a -> Term.VarSet.t *) val unify : ?subsume:bool -> ?scope:Term.Var.scope -> Env.t -> t -> 'a -> 'a -> (Binding.t list * t) option +val unify_map: Env.t -> t -> Term.t Term.VarMap.t -> (Obj.t Term.VarMap.t * t) option + val merge_disjoint : Env.t -> t -> t -> t (* [merge env s1 s2] merges two substituions *) diff --git a/src/core/Term.ml b/src/core/Term.ml index b9612416..1f0bf166 100644 --- a/src/core/Term.ml +++ b/src/core/Term.ml @@ -90,6 +90,10 @@ module VarMap = match f (try Some (find k m) with Not_found -> None) with | Some x -> add k x m | None -> remove k m + + let iteri f m = + let i = ref 0 in + iter (fun k v -> f !i k v; incr i) m end type t = Obj.t diff --git a/src/core/Term.mli b/src/core/Term.mli index bf41ffa5..e99fd6f4 100644 --- a/src/core/Term.mli +++ b/src/core/Term.mli @@ -67,6 +67,9 @@ module VarMap : include Map.S with type key = Var.t val update : key -> ('a option -> 'a option) -> 'a t -> 'a t + + val iteri: (int -> key -> 'a -> unit) -> 'a t -> unit + end (* [t] type of untyped OCaml term *)