Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closures (again) #742

Open
wants to merge 38 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 30 commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
b1fc74e
add closure type
nhuurre Nov 15, 2020
3a1c303
add closure support to fun_defs
nhuurre Nov 15, 2020
9e55fed
refactor function codegen
nhuurre Nov 16, 2020
8dcaf3e
make parser and Symbol_table support closures
nhuurre Nov 18, 2020
819fdde
user-defined higher-order functions
nhuurre Nov 28, 2020
c2e1e51
fix closures that capture closures
nhuurre Nov 28, 2020
c80a1fc
Merge branch 'master' into feature/closures
nhuurre Nov 30, 2020
0c00700
better error messages
nhuurre Dec 1, 2020
9bc95ff
enable reduce_sum closures
nhuurre Dec 1, 2020
a4c0157
fix
nhuurre Dec 3, 2020
58ec0a0
enable capture-by-reference
nhuurre Dec 19, 2020
07e489a
Merge branch 'master' into feature/closures
nhuurre Dec 19, 2020
1118f2b
Merge branch 'master' into feature/closures
nhuurre Jan 29, 2021
3e2b7ba
Merge branch 'master' into feature/closures
nhuurre Feb 18, 2021
935bc54
tweak C++ API
nhuurre Feb 19, 2021
1282f70
allow special suffix closures
nhuurre Feb 20, 2021
f4cf3e6
higher-order userdef suffix functions
nhuurre Feb 21, 2021
b4abd2b
restore missing functor structs
nhuurre Feb 23, 2021
0d20a92
Merge branch 'master' into feature/closures
nhuurre Feb 23, 2021
8e07220
Merge branch 'master' into feature/closures
nhuurre Mar 20, 2021
6f1fe48
new keyword for closures
nhuurre Mar 20, 2021
9fd3f9f
Merge branch 'ode_adjoint' into feature/closures
nhuurre Mar 20, 2021
9e035c0
Merge branch 'master' into feature/closures
nhuurre May 24, 2021
2e6b75c
organize test files
nhuurre May 24, 2021
5b4e572
Merge branch 'feature/type-error-explanations' into feature/closures
nhuurre Jun 8, 2021
f352eff
post-merge cleanup
nhuurre Jun 8, 2021
73cf174
Merge remote-tracking branch 'nhuurre/closures-adjoint-ode' into feat…
nhuurre Jun 8, 2021
a1b0b85
Merge branch 'master' into feature/closures
nhuurre Jun 15, 2021
ce8f83a
simplify ode interface
nhuurre Jun 15, 2021
ef1604d
eval() eigen expressions
nhuurre Jun 15, 2021
2ba4cb9
rename closure test files
nhuurre Jun 16, 2021
0b01625
Merge branch 'master' into feature/closures
nhuurre Jul 19, 2021
15037f4
use from_lambda more
nhuurre Jul 19, 2021
97df74c
use *_from_lambda() instead of custom closure structs
nhuurre Jul 27, 2021
c20d155
dune promote
nhuurre Jul 27, 2021
e5d6414
fix _lp and _rng in higher-order functions
nhuurre Aug 12, 2021
131bb90
Merge branch 'master' into feature/closures
nhuurre Sep 23, 2021
4d39c38
Merge branch 'master' into feature/closures
nhuurre Oct 17, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/analysis_and_optimization/Optimize.ml
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ let rec inline_function_expression propto adt fim
in
let es = List.map ~f:(function _, _, x -> x) dse_list in
match kind with
| CompilerInternal _ ->
| CompilerInternal _ | Closure _ ->
(d_list, s_list, {e with pattern= FunApp (kind, es)})
| UserDefined (fname, suffix) | StanLib (fname, suffix) -> (
let suffix, fname' =
Expand Down Expand Up @@ -382,7 +382,7 @@ let rec inline_function_statement propto adt fim Stmt.Fixed.({pattern; meta}) =
let es = List.map ~f:(function _, _, x -> x) dse_list in
slist_concat_no_loc (d_list @ s_list)
( match kind with
| CompilerInternal _ -> NRFunApp (kind, es)
| CompilerInternal _ | Closure _ -> NRFunApp (kind, es)
| UserDefined (s, _) | StanLib (s, _) -> (
match Map.find fim s with
| None -> NRFunApp (kind, es)
Expand Down
2 changes: 1 addition & 1 deletion src/analysis_and_optimization/Partial_evaluator.ml
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ let rec eval_expr (e : Expr.Typed.t) =
| FunApp (kind, l) -> (
let l = List.map ~f:eval_expr l in
match kind with
| UserDefined _ | CompilerInternal _ -> FunApp (kind, l)
| UserDefined _ | CompilerInternal _ | Closure _ -> FunApp (kind, l)
| StanLib (f, suffix) ->
let get_fun_or_op_rt_opt name l' =
let argument_types =
Expand Down
30 changes: 22 additions & 8 deletions src/frontend/Ast.ml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,16 @@ type 'e index =
type fun_kind =
| StanLib of bool Fun_kind.suffix
| UserDefined of bool Fun_kind.suffix
| Closure of bool Fun_kind.suffix
[@@deriving compare, sexp, hash]

type capture_info =
string
* ( Middle.UnsizedType.capturetype
* Middle.UnsizedType.autodifftype
* Middle.UnsizedType.t
* string )
list
[@@deriving compare, sexp, hash]

(** Expression shapes (used for both typed and untyped expressions, where we
Expand Down Expand Up @@ -117,7 +127,7 @@ type typed_lval = (typed_expression, typed_expr_meta) lval_with
(** Statement shapes, where we substitute untyped_expression and untyped_statement
for 'e and 's respectively to get untyped_statement and typed_expression and
typed_statement to get typed_statement *)
type ('e, 's, 'l, 'f) statement =
type ('e, 's, 'l, 'f, 'c) statement =
| Assignment of
{ assign_lhs: 'l
; assign_op: assignmentoperator
Expand Down Expand Up @@ -157,6 +167,7 @@ type ('e, 's, 'l, 'f) statement =
| FunDef of
{ returntype: Middle.UnsizedType.returntype
; funname: identifier
; captures: 'c option
; arguments:
(Middle.UnsizedType.autodifftype * Middle.UnsizedType.t * identifier)
list
Expand All @@ -178,13 +189,14 @@ type statement_returntype =
| AnyReturnType
[@@deriving sexp, hash, compare]

type ('e, 'm, 'l, 'f) statement_with =
{stmt: ('e, ('e, 'm, 'l, 'f) statement_with, 'l, 'f) statement; smeta: 'm}
type ('e, 'm, 'l, 'f, 'c) statement_with =
{ stmt: ('e, ('e, 'm, 'l, 'f, 'c) statement_with, 'l, 'f, 'c) statement
; smeta: 'm }
[@@deriving sexp, compare, map, hash, fold]

(** Untyped statements, which have location_spans as meta-data *)
type untyped_statement =
(untyped_expression, located_meta, untyped_lval, unit) statement_with
(untyped_expression, located_meta, untyped_lval, unit, unit) statement_with
[@@deriving sexp, compare, map, hash]

let mk_untyped_statement ~stmt ~loc : untyped_statement = {stmt; smeta= {loc}}
Expand All @@ -200,7 +212,8 @@ type typed_statement =
( typed_expression
, stmt_typed_located_meta
, typed_lval
, fun_kind )
, fun_kind
, capture_info )
statement_with
[@@deriving sexp, compare, map, hash]

Expand Down Expand Up @@ -243,12 +256,13 @@ let rec untyped_lvalue_of_typed_lvalue ({lval; lmeta} : typed_lval) :
; lmeta= {loc= lmeta.loc} }

(** Forgetful function from typed to untyped statements *)
let rec untyped_statement_of_typed_statement {stmt; smeta} =
let rec untyped_statement_of_typed_statement ({stmt; smeta} : typed_statement)
=
{ stmt=
map_statement untyped_expression_of_typed_expression
untyped_statement_of_typed_statement untyped_lvalue_of_typed_lvalue
(fun _ -> ())
stmt
(function StanLib _ | UserDefined _ | Closure _ -> ())
Fn.ignore stmt
; smeta= {loc= smeta.loc} }

(** Forgetful function from typed to untyped programs *)
Expand Down
76 changes: 69 additions & 7 deletions src/frontend/Ast_to_Mir.ml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ let trans_fn_kind kind name =
match kind with
| Ast.StanLib suffix -> Fun_kind.StanLib (fname, suffix)
| UserDefined suffix -> UserDefined (fname, suffix)
| Closure suffix -> Closure (fname, suffix)

let without_underscores = String.filter ~f:(( <> ) '_')

Expand All @@ -36,6 +37,8 @@ let%expect_test "format_number1" =
format_number ".123_456" |> print_endline ;
[%expect ".123456"]

let closures = ref String.Map.empty

let rec op_to_funapp op args =
let argtypes =
List.map ~f:(fun x -> (x.Ast.emeta.Ast.ad_level, x.emeta.type_)) args
Expand Down Expand Up @@ -497,9 +500,9 @@ let unwrap_block_or_skip = function

let rec trans_stmt ud_dists (declc : decl_context) (ts : Ast.typed_statement) =
let stmt_typed = ts.stmt and smeta = ts.smeta.loc in
let trans_stmt = trans_stmt ud_dists {declc with dconstrain= None} in
let translate_stmt = trans_stmt ud_dists {declc with dconstrain= None} in
let trans_single_stmt s =
match trans_stmt s with
match translate_stmt s with
| [s] -> s
| s -> Stmt.Fixed.{pattern= SList s; meta= smeta}
in
Expand Down Expand Up @@ -633,7 +636,63 @@ let rec trans_stmt ud_dists (declc : decl_context) (ts : Ast.typed_statement) =
; meta= smeta }
in
Stmt.Helpers.[ensure_var (for_each bodyfn) iteratee' smeta]
| Ast.FunDef _ ->
| Ast.FunDef
{ returntype
; funname
; captures= Some (implname, captures)
; arguments
; body } ->
let fdsuffix = Fun_kind.suffix_from_name funname.name in
if Map.find !closures implname = None then
closures :=
String.Map.add_exn !closures ~key:implname
~data:
{ Program.fdrt=
( match returntype with
| Void -> None
| ReturnType ut -> Some ut )
; fdname= implname
; fdsuffix=
Fun_kind.(
suffix_from_name funname.name |> map_suffix Fn.ignore)
; fdcaptures=
Some
(List.map
~f:(fun (ref, ad, ty, id) -> (ref, ad, id, ty))
captures)
; fdargs= List.map ~f:trans_arg arguments
; fdbody=
trans_stmt ud_dists
{dconstrain= None; dadlevel= AutoDiffable}
body
|> unwrap_block_or_skip
; fdloc= ts.smeta.loc } ;
let arguments = List.map ~f:(fun (ad, ut, _) -> (ad, ut)) arguments in
let type_ = UnsizedType.UFun (arguments, returntype, (fdsuffix, true)) in
let captures =
List.map
~f:(fun (_, adlevel, type_, id) ->
Expr.
{ Fixed.pattern= Var id
; meta= Typed.Meta.{adlevel; type_; loc= mloc} } )
captures
in
[ { pattern=
Decl
{ decl_adtype= AutoDiffable
; decl_id= funname.name
; decl_type= Unsized type_ }
; meta= smeta }
; { pattern=
Assignment
( (funname.name, type_, [])
, { pattern=
FunApp
( CompilerInternal FnMakeClosure
, Expr.Helpers.str implname :: captures )
; meta= {type_; adlevel= AutoDiffable; loc= mloc} } )
; meta= smeta } ]
| Ast.FunDef {captures= None; _} ->
raise_s
[%message
"Found function definition statement outside of function block"]
Expand All @@ -642,9 +701,9 @@ let rec trans_stmt ud_dists (declc : decl_context) (ts : Ast.typed_statement) =
trans_decl declc smeta decl_type
(Program.map_transformation trans_expr transformation)
identifier initial_value
| Ast.Block stmts -> Block (List.concat_map ~f:trans_stmt stmts) |> swrap
| Ast.Block stmts -> Block (List.concat_map ~f:translate_stmt stmts) |> swrap
| Ast.Profile (name, stmts) ->
Profile (name, List.concat_map ~f:trans_stmt stmts) |> swrap
Profile (name, List.concat_map ~f:translate_stmt stmts) |> swrap
| Ast.Return e -> Return (Some (trans_expr e)) |> swrap
| Ast.ReturnVoid -> Return None |> swrap
| Ast.Break -> Break |> swrap
Expand All @@ -653,13 +712,14 @@ let rec trans_stmt ud_dists (declc : decl_context) (ts : Ast.typed_statement) =

let trans_fun_def ud_dists (ts : Ast.typed_statement) =
match ts.stmt with
| Ast.FunDef {returntype; funname; arguments; body} ->
| Ast.FunDef {returntype; funname; captures= None; arguments; body} ->
[ Program.
{ fdrt=
(match returntype with Void -> None | ReturnType ut -> Some ut)
; fdname= funname.name
; fdsuffix=
Fun_kind.(suffix_from_name funname.name |> without_propto)
; fdcaptures= None
; fdargs= List.map ~f:trans_arg arguments
; fdbody=
trans_stmt ud_dists
Expand Down Expand Up @@ -838,6 +898,7 @@ let migrate_checks_to_end_of_block stmts =
not_checks @ checks

let trans_prog filename (p : Ast.typed_program) : Program.Typed.t =
closures := String.Map.empty ;
let {Ast.functionblock; datablock; transformeddatablock; modelblock; _} =
p
in
Expand Down Expand Up @@ -945,7 +1006,8 @@ let trans_prog filename (p : Ast.typed_program) : Program.Typed.t =
"_" ^ prog_name
else prog_name
in
{ functions_block= map (trans_fun_def ud_dists) functionblock
let functions = map (trans_fun_def ud_dists) functionblock in
{ functions_block= functions @ List.map ~f:snd (Map.to_alist !closures)
; input_vars
; prepare_data
; log_prob
Expand Down
12 changes: 8 additions & 4 deletions src/frontend/Canonicalize.ml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@ let rec repair_syntax_stmt user_dists {stmt; smeta} =
; smeta }
| _ ->
{ stmt=
map_statement ident (repair_syntax_stmt user_dists) ident ident stmt
map_statement Fn.id
(repair_syntax_stmt user_dists)
Fn.id Fn.id Fn.id stmt
; smeta }

let rec replace_deprecated_expr
Expand Down Expand Up @@ -91,7 +93,8 @@ let rec replace_deprecated_stmt
{ assign_lhs= replace_deprecated_lval deprecated_userdefined l
; assign_op= Assign
; assign_rhs= (replace_deprecated_expr deprecated_userdefined) e }
| FunDef {returntype; funname= {name; id_loc}; arguments; body} ->
| FunDef {returntype; funname= {name; id_loc}; captures; arguments; body}
->
let newname =
match String.Map.find deprecated_userdefined name with
| Some type_ -> update_suffix name type_
Expand All @@ -100,14 +103,15 @@ let rec replace_deprecated_stmt
FunDef
{ returntype
; funname= {name= newname; id_loc}
; captures
; arguments
; body= replace_deprecated_stmt deprecated_userdefined body }
| _ ->
map_statement
(replace_deprecated_expr deprecated_userdefined)
(replace_deprecated_stmt deprecated_userdefined)
(replace_deprecated_lval deprecated_userdefined)
ident stmt
Fn.id Fn.id stmt
in
{stmt; smeta}

Expand Down Expand Up @@ -164,7 +168,7 @@ let rec parens_stmt {stmt; smeta} =
; lower_bound= keep_parens lower_bound
; upper_bound= keep_parens upper_bound
; loop_body= parens_stmt loop_body }
| _ -> map_statement no_parens parens_stmt parens_lval ident stmt
| _ -> map_statement no_parens parens_stmt parens_lval Fn.id Fn.id stmt
in
{stmt; smeta}

Expand Down
1 change: 1 addition & 0 deletions src/frontend/Deprecation_analysis.ml
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ let rec collect_deprecated_stmt deprecated_userdefined
(collect_deprecated_stmt deprecated_userdefined)
(collect_deprecated_lval deprecated_userdefined)
(fun l _ -> l)
(fun l _ -> l)
acc stmt

let collect_userdef_distributions program =
Expand Down
1 change: 1 addition & 0 deletions src/frontend/Info.ml
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ let rec get_function_calls_stmt ud_dists (funs, distrs) stmt =
(get_function_calls_stmt ud_dists)
(fun acc _ -> acc)
(fun acc _ -> acc)
(fun acc _ -> acc)
acc stmt.stmt

let function_calls ppf p =
Expand Down
16 changes: 7 additions & 9 deletions src/frontend/Pretty_printing.ml
Original file line number Diff line number Diff line change
Expand Up @@ -42,20 +42,17 @@ let with_indented_box ppf indentation offset f =

let pp_unsizedtype = Middle.UnsizedType.pp
let pp_autodifftype = Middle.UnsizedType.pp_autodifftype
let pp_returntype = Middle.UnsizedType.pp_returntype

let rec unwind_sized_array_type = function
| Middle.SizedType.SArray (st, e) -> (
match unwind_sized_array_type st with st2, es -> (st2, es @ [e]) )
| st -> (st, [])

let pp_unsizedtypes ppf l = Fmt.(list ~sep:comma_no_break pp_unsizedtype) ppf l

let pp_argtype ppf = function
| at, ut -> Fmt.pair ~sep:Fmt.nop pp_autodifftype pp_unsizedtype ppf (at, ut)

let pp_returntype ppf = function
| Middle.UnsizedType.ReturnType x -> pp_unsizedtype ppf x
| Void -> Fmt.pf ppf "void"
let rec unwind_array_type = function
| Middle.UnsizedType.UArray ut -> (
match unwind_array_type ut with ut2, d -> (ut2, d + 1) )
| ut -> (ut, 0)

let pp_identifier ppf id = Fmt.pf ppf "%s" id.name

Expand Down Expand Up @@ -338,7 +335,8 @@ and pp_statement ppf ({stmt= s_content; _} as ss) =
with_hbox ppf (fun () ->
Fmt.pf ppf "%a%a %a%a;" pp_array_dims es pp_transformed_type
(pst, trans) pp_identifier id pp_init init )
| FunDef {returntype= rt; funname= id; arguments= args; body= b} -> (
| FunDef {returntype= rt; funname= id; captures; arguments= args; body= b} -> (
if is_some captures then Fmt.pf ppf "function" ;
Fmt.pf ppf "%a %a(" pp_returntype rt pp_identifier id ;
with_box ppf 0 (fun () ->
Fmt.pf ppf "%a" (Fmt.list ~sep:Fmt.comma pp_args) args ) ;
Expand Down
Loading