Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Closures prototype #570

Closed
wants to merge 10 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 15 additions & 6 deletions src/frontend/Ast.ml
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ type typed_lval = (typed_expression, typed_expr_meta) lval_with
(** Statement shapes, where we substitute untyped_expression and untyped_statement
for 'e and 's respectively to get untyped_statement and typed_expression and
typed_statement to get typed_statement *)
type ('e, 's, 'l, 'f) statement =
type ('e, 's, 'l, 'f, 'c) statement =
| Assignment of
{ assign_lhs: 'l
; assign_op: assignmentoperator
Expand Down Expand Up @@ -155,6 +155,7 @@ type ('e, 's, 'l, 'f) statement =
| FunDef of
{ returntype: Middle.UnsizedType.returntype
; funname: identifier
; closure: 'c
; arguments:
(Middle.UnsizedType.autodifftype * Middle.UnsizedType.t * identifier)
list
Expand All @@ -176,13 +177,20 @@ type statement_returntype =
| AnyReturnType
[@@deriving sexp, hash, compare]

type ('e, 'm, 'l, 'f) statement_with =
{stmt: ('e, ('e, 'm, 'l, 'f) statement_with, 'l, 'f) statement; smeta: 'm}
type ('e, 'm, 'l, 'f, 'c) statement_with =
{ stmt: ('e, ('e, 'm, 'l, 'f, 'c) statement_with, 'l, 'f, 'c) statement
; smeta: 'm }
[@@deriving sexp, compare, map, hash]

type closure_info =
{ clname: string
; captures:
(Middle.UnsizedType.autodifftype * Middle.UnsizedType.t * string) list }
[@@deriving sexp, compare, map, hash]

(** Untyped statements, which have location_spans as meta-data *)
type untyped_statement =
(untyped_expression, located_meta, untyped_lval, unit) statement_with
(untyped_expression, located_meta, untyped_lval, unit, bool) statement_with
[@@deriving sexp, compare, map, hash]

let mk_untyped_statement ~stmt ~loc : untyped_statement = {stmt; smeta= {loc}}
Expand All @@ -198,7 +206,8 @@ type typed_statement =
( typed_expression
, stmt_typed_located_meta
, typed_lval
, fun_kind )
, fun_kind
, closure_info option )
statement_with
[@@deriving sexp, compare, map, hash]

Expand Down Expand Up @@ -246,7 +255,7 @@ let rec untyped_statement_of_typed_statement {stmt; smeta} =
map_statement untyped_expression_of_typed_expression
untyped_statement_of_typed_statement untyped_lvalue_of_typed_lvalue
(fun _ -> ())
stmt
Option.is_some stmt
; smeta= {loc= smeta.loc} }

(** Forgetful function from typed to untyped programs *)
Expand Down
64 changes: 61 additions & 3 deletions src/frontend/Ast_to_Mir.ml
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
open Core_kernel
open Middle

let closures = ref String.Map.empty

(* XXX fix exn *)
let unwrap_return_exn = function
| Some (UnsizedType.ReturnType ut) -> ut
Expand Down Expand Up @@ -72,6 +74,16 @@ and trans_expr {Ast.expr; Ast.emeta} =
| Variable {name; _} -> Var name
| IntNumeral x -> Lit (Int, format_number x)
| RealNumeral x -> Lit (Real, format_number x)
| FunApp
( Ast.StanLib
, {name= "integrate_ode_bdf"; _}
, ({emeta= {type_= UFun (fargs, _, _); _}; _} :: _ as args) )
when 2 + List.length fargs = List.length args ->
FunApp
( Fun_kind.StanLib
, "integrate_ode_bdf"
, trans_exprs args
@ Expr.Helpers.[float 1e-10; float 1e-10; float 1e8] )
| FunApp (fn_kind, {name; _}, args)
|CondDistApp (fn_kind, {name; _}, args) ->
FunApp (trans_fn_kind fn_kind, name, trans_exprs args)
Expand Down Expand Up @@ -467,6 +479,9 @@ let%expect_test "dist name suffix" =

let rec trans_stmt ud_dists (declc : decl_context) (ts : Ast.typed_statement) =
let stmt_typed = ts.stmt and smeta = ts.smeta.loc in
let trans_stmt_ad =
trans_stmt ud_dists {dconstrain= None; dadlevel= AutoDiffable}
in
let trans_stmt = trans_stmt ud_dists {declc with dconstrain= None} in
let trans_single_stmt s =
match trans_stmt s with
Expand Down Expand Up @@ -612,10 +627,52 @@ let rec trans_stmt ud_dists (declc : decl_context) (ts : Ast.typed_statement) =
; meta= smeta }
in
Stmt.Helpers.[ensure_var (for_each bodyfn) iteratee' smeta]
| Ast.FunDef _ ->
| Ast.FunDef {closure= None; _} ->
raise_s
[%message
"Found function definition statement outside of function block"]
| Ast.FunDef
{returntype; funname; closure= Some {clname; captures}; arguments; body}
->
let type_ =
UnsizedType.UFun
(List.map arguments ~f:(fun (a, t, _) -> (a, t)), returntype, Closure)
in
( match
Map.add !closures ~key:clname
~data:
{ Program.cdrt=
( match returntype with
| Void -> None
| ReturnType ut -> Some ut )
; cdcaptures= List.map captures ~f:(fun (a, t, n) -> (a, n, t))
; cdargs= List.map arguments ~f:trans_arg
; cdbody= trans_stmt_ad body |> unwrap_block_or_skip }
with
| `Ok x -> closures := x
| `Duplicate -> () ) ;
[ Stmt.
{ Fixed.pattern=
Decl
{ decl_adtype= declc.dadlevel
; decl_id= funname.name
; decl_type= Unsized type_ }
; meta= smeta }
; { pattern=
Assignment
( (funname.name, type_, [])
, Expr.
{ pattern=
FunApp
( CompilerInternal
, Internal_fun.to_string FnMakeClosure
, Helpers.str clname
:: List.map captures ~f:(fun (adlevel, type_, id) ->
{ Fixed.pattern= Var id
; meta= {Typed.Meta.type_; adlevel; loc= smeta}
} ) )
; meta= {type_; adlevel= declc.dadlevel; loc= smeta} } )
; meta= smeta } ]
| Ast.VarDecl
{decl_type; transformation; identifier; initial_value; is_global} ->
ignore is_global ;
Expand All @@ -631,7 +688,7 @@ let rec trans_stmt ud_dists (declc : decl_context) (ts : Ast.typed_statement) =

let trans_fun_def ud_dists (ts : Ast.typed_statement) =
match ts.stmt with
| Ast.FunDef {returntype; funname; arguments; body} ->
| Ast.FunDef {returntype; funname; closure= None; arguments; body} ->
[ Program.
{ fdrt=
(match returntype with Void -> None | ReturnType ut -> Some ut)
Expand Down Expand Up @@ -765,7 +822,8 @@ let trans_prog filename (p : Ast.typed_program) : Program.Typed.t =
let transform_inits =
gen_from_block {declc with dconstrain= Some Unconstrain} Parameters
in
{ functions_block= map (trans_fun_def ud_dists) functionblock
{ closures= !closures
; functions_block= map (trans_fun_def ud_dists) functionblock
; input_vars
; prepare_data
; log_prob
Expand Down
14 changes: 10 additions & 4 deletions src/frontend/Canonicalize.ml
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ let rec repair_syntax_stmt user_dists {stmt; smeta} =
{ stmt=
map_statement repair_syntax_expr
(repair_syntax_stmt user_dists)
repair_syntax_lval ident stmt
repair_syntax_lval ident ident stmt
; smeta }

let rec replace_deprecated_expr {expr; emeta} =
Expand Down Expand Up @@ -140,19 +140,25 @@ let rec replace_deprecated_stmt {stmt; smeta} =
{ assign_lhs= replace_deprecated_lval l
; assign_op= Assign
; assign_rhs= replace_deprecated_expr e }
| FunDef {returntype; funname= {name; id_loc}; arguments; body} ->
| FunDef
{ returntype
; funname= {name; id_loc}
; closure
; arguments
; body } ->
FunDef
{ returntype
; funname=
{ name=
Option.value ~default:name
(String.Table.find deprecated_userdefined name)
; id_loc }
; closure
; arguments
; body= replace_deprecated_stmt body }
| _ ->
map_statement replace_deprecated_expr replace_deprecated_stmt
replace_deprecated_lval ident stmt
replace_deprecated_lval ident ident stmt
in
{stmt; smeta}

Expand Down Expand Up @@ -209,7 +215,7 @@ let rec parens_stmt {stmt; smeta} =
; lower_bound= keep_parens lower_bound
; upper_bound= keep_parens upper_bound
; loop_body= parens_stmt loop_body }
| _ -> map_statement no_parens parens_stmt parens_lval ident stmt
| _ -> map_statement no_parens parens_stmt parens_lval ident ident stmt
in
{stmt; smeta}

Expand Down
5 changes: 3 additions & 2 deletions src/frontend/Pretty_printing.ml
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ and pp_unsizedtype ppf = function
let ut2, d = unwind_array_type ut in
let array_str = "[" ^ String.make d ',' ^ "]" in
Fmt.(suffix (const string array_str) pp_unsizedtype ppf ut2)
| UFun (argtypes, rt) ->
| UFun (argtypes, rt, _) ->
Fmt.pf ppf "{|@[<h>(%a) => %a@]|}"
Fmt.(list ~sep:comma_no_break pp_argtype)
argtypes pp_returntype rt
Expand Down Expand Up @@ -354,7 +354,8 @@ and pp_statement ppf ({stmt= s_content; _} as ss) =
with_hbox ppf (fun () ->
Fmt.pf ppf "%a %a%a%a;" pp_transformed_type (pst, trans)
pp_identifier id pp_array_dims es pp_init init )
| FunDef {returntype= rt; funname= id; arguments= args; body= b} -> (
| FunDef {returntype= rt; funname= id; arguments= args; body= b; closure= _}
-> (
Fmt.pf ppf "%a %a(" pp_returntype rt pp_identifier id ;
with_box ppf 0 (fun () ->
Fmt.pf ppf "%a" (Fmt.list ~sep:Fmt.comma pp_args) args ) ;
Expand Down
Loading