Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Clean up AST by specifying vardecl as sized #1203

Merged
merged 1 commit into from
May 26, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 11 additions & 12 deletions src/frontend/Ast.ml
Original file line number Diff line number Diff line change
Expand Up @@ -156,13 +156,13 @@ type ('e, 's, 'l, 'f) statement =
| Profile of string * 's list
| Block of 's list
| VarDecl of
{ decl_type: 'e Middle.Type.t
{ decl_type: 'e SizedType.t
; transformation: 'e Transformation.t
; identifier: identifier
; initial_value: 'e option
; is_global: bool }
| FunDef of
{ returntype: Middle.UnsizedType.returntype
{ returntype: UnsizedType.returntype
; funname: identifier
; arguments:
(Middle.UnsizedType.autodifftype * Middle.UnsizedType.t * identifier)
Expand Down Expand Up @@ -325,17 +325,16 @@ let rec get_loc_expr (e : untyped_expression) =
e.emeta.loc.end_loc
| FunApp (_, id, _) | CondDistApp (_, id, _) -> id.id_loc.end_loc

let get_loc_dt (t : untyped_expression Type.t) =
let get_loc_dt (t : untyped_expression SizedType.t) =
match t with
| Type.Unsized _ | Sized (SInt | SReal | SComplex) -> None
| Sized
( SVector (_, e)
| SRowVector (_, e)
| SMatrix (_, e, _)
| SComplexVector e
| SComplexRowVector e
| SComplexMatrix (e, _)
| SArray (_, e) ) ->
| SInt | SReal | SComplex -> None
| SVector (_, e)
|SRowVector (_, e)
|SMatrix (_, e, _)
|SComplexVector e
|SComplexRowVector e
|SComplexMatrix (e, _)
|SArray (_, e) ->
Some e.emeta.loc.begin_loc

let get_loc_tf (t : untyped_expression Transformation.t) =
Expand Down
23 changes: 9 additions & 14 deletions src/frontend/Ast_to_Mir.ml
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ let rec check_decl var decl_type' decl_id decl_trans smeta adlevel =
[check_id var]
| _ -> []

let check_sizedtype name =
let check_sizedtype name st =
let check x = function
| {Expr.Fixed.pattern= Lit (Int, i); _} when float_of_string i >= 0. -> []
| n ->
Expand Down Expand Up @@ -324,13 +324,11 @@ let check_sizedtype name =
let e = trans_expr s in
let ll, t = sizedtype t in
(check s e @ ll, SizedType.SArray (t, e)) in
function
| Type.Sized st ->
let ll, st = sizedtype st in
(ll, Type.Sized st)
| Unsized ut -> ([], Unsized ut)
let ll, st = sizedtype st in
(ll, Type.Sized st)

let trans_decl {transform_action; dadlevel} smeta decl_type transform identifier
let trans_decl {transform_action; dadlevel} smeta
(decl_type : Ast.typed_expression SizedType.t) transform identifier
initial_value =
let decl_id = identifier.Ast.name in
let rhs = Option.map ~f:trans_expr initial_value in
Expand All @@ -341,7 +339,7 @@ let trans_decl {transform_action; dadlevel} smeta decl_type transform identifier
{ Fixed.pattern= Var decl_id
; meta=
Typed.Meta.create ~adlevel:dadlevel ~loc:smeta
~type_:(Type.to_unsized decl_type)
~type_:(SizedType.to_unsized decl_type)
() } in
let decl =
Stmt.
Expand Down Expand Up @@ -631,7 +629,7 @@ let trans_block ud_dists declc block prog =
match stmt with
| { Ast.stmt=
VarDecl
{ decl_type= Sized type_
{ decl_type= type_
; identifier
; transformation
; initial_value
Expand Down Expand Up @@ -706,10 +704,7 @@ let gather_data (p : Ast.typed_program) =
List.filter_map data ~f:(function
| { stmt=
VarDecl
{ decl_type= Sized sizedtype
; transformation
; identifier= {name; _}
; _ }
{decl_type= sizedtype; transformation; identifier= {name; _}; _}
; _ } ->
Some
( SizedType.map trans_expr sizedtype
Expand All @@ -731,7 +726,7 @@ let trans_prog filename (p : Ast.typed_program) : Program.Typed.t =
let trans_stmt = trans_stmt ud_dists in
let get_name_size s =
match s.Ast.stmt with
| Ast.VarDecl {decl_type= Sized st; identifier; transformation; _} ->
| Ast.VarDecl {decl_type= st; identifier; transformation; _} ->
[(identifier.name, trans_sizedtype st, transformation)]
| _ -> [] in
let input_vars =
Expand Down
2 changes: 1 addition & 1 deletion src/frontend/Canonicalize.ml
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ let rec parens_stmt ({stmt; smeta} : typed_statement) : typed_statement =
; initial_value= init
; is_global } ->
VarDecl
{ decl_type= Middle.Type.map no_parens d
{ decl_type= Middle.SizedType.map no_parens d
; transformation= Middle.Transformation.map keep_parens t
; identifier
; initial_value= Option.map ~f:no_parens init
Expand Down
2 changes: 1 addition & 1 deletion src/frontend/Info.ml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ let unsized_basetype_json t =
`Assoc [("type", `String type_); ("dimensions", `Int dim)] in
type_dims t |> to_json

let basetype_dims t = Type.to_unsized t |> unsized_basetype_json
let basetype_dims t = SizedType.to_unsized t |> unsized_basetype_json

let get_var_decl {stmts; _} : t =
`Assoc
Expand Down
99 changes: 46 additions & 53 deletions src/frontend/Pretty_printing.ml
Original file line number Diff line number Diff line change
Expand Up @@ -342,61 +342,54 @@ let pp_bracketed_transform ppf = function
|CholeskyCov | Correlation | Covariance ->
()

let pp_transformed_type ppf (pst, trans) =
let pp_transformed_type ppf (st, trans) =
let open Middle in
match pst with
| Type.Unsized ust ->
(* unsized types are untransformed *)
pf ppf "%a" pp_unsizedtype ust
| Type.Sized st -> (
let pp_possibly_transformed_type ppf (st, trans) =
let sizes_fmt =
match st with
| SizedType.SVector (_, e)
|SRowVector (_, e)
|SComplexVector e
|SComplexRowVector e ->
const (fun ppf -> pf ppf "[%a]" pp_expression) e
| SMatrix (_, e1, e2) | SComplexMatrix (e1, e2) ->
const
(fun ppf -> pf ppf "[%a, %a]" pp_expression e1 pp_expression)
e2
| SArray _ | SInt | SReal | SComplex -> nop in
let cov_sizes_fmt =
match st with
| SMatrix (_, e1, e2) ->
if e1 = e2 then const (fun ppf -> pf ppf "[%a]" pp_expression) e1
else
const
(fun ppf -> pf ppf "[%a, %a]" pp_expression e1 pp_expression)
e2
| _ -> nop in
match trans with
| Transformation.Identity -> pf ppf "%a" pp_sizedtype st
| Lower _ | Upper _ | LowerUpper _ | Offset _ | Multiplier _
|OffsetMultiplier _ ->
pf ppf "%a%a%a" pp_unsizedtype (SizedType.to_unsized st)
pp_bracketed_transform trans sizes_fmt ()
| Ordered -> pf ppf "ordered%a" sizes_fmt ()
| PositiveOrdered -> pf ppf "positive_ordered%a" sizes_fmt ()
| Simplex -> pf ppf "simplex%a" sizes_fmt ()
| UnitVector -> pf ppf "unit_vector%a" sizes_fmt ()
| CholeskyCorr -> pf ppf "cholesky_factor_corr%a" cov_sizes_fmt ()
| CholeskyCov -> pf ppf "cholesky_factor_cov%a" cov_sizes_fmt ()
| Correlation -> pf ppf "corr_matrix%a" cov_sizes_fmt ()
| Covariance -> pf ppf "cov_matrix%a" cov_sizes_fmt () in
let pp_possibly_transformed_type ppf (st, trans) =
let sizes_fmt =
match st with
(* array goes before something like cov_matrix *)
| Middle.SizedType.SArray _ ->
let ty, ixs = unwind_sized_array_type st in
let ({emeta= {loc= {end_loc; _}; _}; _} : untyped_expression) =
List.last_exn ixs in
let ({emeta= {loc= {begin_loc; _}; _}; _} : untyped_expression) =
List.hd_exn ixs in
pf ppf "array[@[%a@]]@ %a" pp_list_of_expression
(ixs, {begin_loc; end_loc})
pp_possibly_transformed_type (ty, trans)
| _ -> pf ppf "%a" pp_possibly_transformed_type (st, trans) )
| SizedType.SVector (_, e)
|SRowVector (_, e)
|SComplexVector e
|SComplexRowVector e ->
const (fun ppf -> pf ppf "[%a]" pp_expression) e
| SMatrix (_, e1, e2) | SComplexMatrix (e1, e2) ->
const (fun ppf -> pf ppf "[%a, %a]" pp_expression e1 pp_expression) e2
| SArray _ | SInt | SReal | SComplex -> nop in
let cov_sizes_fmt =
match st with
| SMatrix (_, e1, e2) ->
if e1 = e2 then const (fun ppf -> pf ppf "[%a]" pp_expression) e1
else
const
(fun ppf -> pf ppf "[%a, %a]" pp_expression e1 pp_expression)
e2
| _ -> nop in
match trans with
| Transformation.Identity -> pf ppf "%a" pp_sizedtype st
| Lower _ | Upper _ | LowerUpper _ | Offset _ | Multiplier _
|OffsetMultiplier _ ->
pf ppf "%a%a%a" pp_unsizedtype (SizedType.to_unsized st)
pp_bracketed_transform trans sizes_fmt ()
| Ordered -> pf ppf "ordered%a" sizes_fmt ()
| PositiveOrdered -> pf ppf "positive_ordered%a" sizes_fmt ()
| Simplex -> pf ppf "simplex%a" sizes_fmt ()
| UnitVector -> pf ppf "unit_vector%a" sizes_fmt ()
| CholeskyCorr -> pf ppf "cholesky_factor_corr%a" cov_sizes_fmt ()
| CholeskyCov -> pf ppf "cholesky_factor_cov%a" cov_sizes_fmt ()
| Correlation -> pf ppf "corr_matrix%a" cov_sizes_fmt ()
| Covariance -> pf ppf "cov_matrix%a" cov_sizes_fmt () in
match st with
(* array goes before something like cov_matrix *)
| Middle.SizedType.SArray _ ->
let ty, ixs = unwind_sized_array_type st in
let ({emeta= {loc= {end_loc; _}; _}; _} : untyped_expression) =
List.last_exn ixs in
let ({emeta= {loc= {begin_loc; _}; _}; _} : untyped_expression) =
List.hd_exn ixs in
pf ppf "array[@[%a@]]@ %a" pp_list_of_expression
(ixs, {begin_loc; end_loc})
pp_possibly_transformed_type (ty, trans)
| _ -> pf ppf "%a" pp_possibly_transformed_type (st, trans)

let rec pp_indent_unless_block ppf ((s : untyped_statement), loc) =
match s.stmt with
Expand Down
14 changes: 4 additions & 10 deletions src/frontend/Typechecker.ml
Original file line number Diff line number Diff line change
Expand Up @@ -1495,7 +1495,7 @@ and check_var_decl loc cf tenv sized_ty trans id init is_global =
verify_transformed_param_ty loc cf is_global unsized_type ;
let stmt =
VarDecl
{ decl_type= Sized checked_type
{ decl_type= checked_type
; transformation= checked_trans
; identifier= id
; initial_value= tinit
Expand Down Expand Up @@ -1684,16 +1684,10 @@ and check_statement (cf : context_flags_record) (tenv : Env.t)
| ForEach (id, e, s) -> (tenv, check_foreach loc cf tenv id e s)
| Block stmts -> (tenv, check_block loc cf tenv stmts)
| Profile (name, vdsl) -> (tenv, check_profile loc cf tenv name vdsl)
| VarDecl {decl_type= Unsized _; _} ->
(* currently unallowed by parser *)
Common.FatalError.fatal_error_msg
[%message "Don't support unsized declarations yet."]
(* these two are special in that they're allowed to change the type environment *)
| VarDecl
{decl_type= Sized st; transformation; identifier; initial_value; is_global}
->
check_var_decl loc cf tenv st transformation identifier initial_value
is_global
| VarDecl {decl_type; transformation; identifier; initial_value; is_global} ->
check_var_decl loc cf tenv decl_type transformation identifier
initial_value is_global
| FunDef {returntype; funname; arguments; body} ->
check_fundef loc cf tenv returntype funname arguments body

Expand Down
4 changes: 2 additions & 2 deletions src/frontend/parser.mly
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,7 @@ decl(type_rule, rhs):
(fun ~is_global ->
[{ stmt=
VarDecl {
decl_type= Sized (reducearray (fst ty, dims))
decl_type= (reducearray (fst ty, dims))
; transformation= snd ty
; identifier= id
; initial_value= rhs_opt
Expand All @@ -362,7 +362,7 @@ decl(type_rule, rhs):
List.map vs ~f:(fun (id, rhs_opt) ->
{ stmt=
VarDecl {
decl_type= Sized (reducearray (fst ty, dims))
decl_type= (reducearray (fst ty, dims))
; transformation= snd ty
; identifier= id
; initial_value= rhs_opt
Expand Down
16 changes: 7 additions & 9 deletions test/unit/Parse_tests.ml
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ let%expect_test "parse minus unary" =
(modelblock
(((stmts
(((stmt
(VarDecl (decl_type (Sized SReal)) (transformation Identity)
(VarDecl (decl_type SReal) (transformation Identity)
(identifier ((name x) (id_loc <opaque>))) (initial_value ())
(is_global false)))
(smeta ((loc <opaque>))))
Expand Down Expand Up @@ -115,7 +115,7 @@ let%expect_test "parse unary over binary" =
(modelblock
(((stmts
(((stmt
(VarDecl (decl_type (Sized SReal)) (transformation Identity)
(VarDecl (decl_type SReal) (transformation Identity)
(identifier ((name x) (id_loc <opaque>)))
(initial_value
(((expr
Expand Down Expand Up @@ -158,9 +158,8 @@ let%expect_test "parse indices, two different colons" =
(((stmt
(VarDecl
(decl_type
(Sized
(SMatrix AoS ((expr (IntNumeral 5)) (emeta ((loc <opaque>))))
((expr (IntNumeral 5)) (emeta ((loc <opaque>)))))))
(SMatrix AoS ((expr (IntNumeral 5)) (emeta ((loc <opaque>))))
((expr (IntNumeral 5)) (emeta ((loc <opaque>))))))
(transformation Identity) (identifier ((name x) (id_loc <opaque>)))
(initial_value ()) (is_global false)))
(smeta ((loc <opaque>))))
Expand Down Expand Up @@ -386,10 +385,9 @@ let%expect_test "parse crazy truncation example" =
(((stmt
(VarDecl
(decl_type
(Sized
(SArray
(SArray SReal ((expr (IntNumeral 1)) (emeta ((loc <opaque>)))))
((expr (IntNumeral 1)) (emeta ((loc <opaque>)))))))
(SArray
(SArray SReal ((expr (IntNumeral 1)) (emeta ((loc <opaque>)))))
((expr (IntNumeral 1)) (emeta ((loc <opaque>))))))
(transformation Identity) (identifier ((name T) (id_loc <opaque>)))
(initial_value
(((expr
Expand Down