Skip to content

Commit

Permalink
address review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
rok-cesnovar committed May 18, 2021
1 parent 25d939e commit cee75fa
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 12 deletions.
7 changes: 2 additions & 5 deletions src/frontend/Semantic_check.ml
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ let semantic_check_variadic_ode ~is_cond_dist ~loc id es =
let optional_tol_mandatory_args =
if Stan_math_signatures.variadic_ode_adjoint_fn = id.name then
Stan_math_signatures.variadic_ode_adjoint_ctl_tol_arg_types
else if Stan_math_signatures.is_nonadjoint_variadic_ode_tol_fn id.name then
else if Stan_math_signatures.is_variadic_ode_nonadjoint_tol_fn id.name then
Stan_math_signatures.variadic_ode_tol_arg_types
else []
in
Expand All @@ -396,11 +396,8 @@ let semantic_check_variadic_ode ~is_cond_dist ~loc id es =
{type_= UnsizedType.UFun (fun_args, ReturnType return_type, FnPlain); _}; _
}
:: args ->
let num_of_mandatory_args =
List.length optional_tol_mandatory_args + 3
in
let mandatory_args, variadic_args =
List.split_n args num_of_mandatory_args
List.split_n args (List.length mandatory_arg_types)
in
let mandatory_fun_args, variadic_fun_args = List.split_n fun_args 2 in
if
Expand Down
4 changes: 2 additions & 2 deletions src/middle/Semantic_error.ml
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ module TypeError = struct
let optional_tol_args =
if Stan_math_signatures.variadic_ode_adjoint_fn = name then
Stan_math_signatures.variadic_ode_adjoint_ctl_tol_arg_types
else if Stan_math_signatures.is_nonadjoint_variadic_ode_tol_fn name
else if Stan_math_signatures.is_variadic_ode_nonadjoint_tol_fn name
then Stan_math_signatures.variadic_ode_tol_arg_types
else []
in
Expand All @@ -182,7 +182,7 @@ module TypeError = struct
let optional_tol_args =
if Stan_math_signatures.variadic_ode_adjoint_fn = name then
types Stan_math_signatures.variadic_ode_adjoint_ctl_tol_arg_types
else if Stan_math_signatures.is_nonadjoint_variadic_ode_tol_fn name
else if Stan_math_signatures.is_variadic_ode_nonadjoint_tol_fn name
then types Stan_math_signatures.variadic_ode_tol_arg_types
else []
in
Expand Down
10 changes: 5 additions & 5 deletions src/middle/Stan_math_signatures.ml
Original file line number Diff line number Diff line change
Expand Up @@ -178,20 +178,20 @@ let reduce_sum_functions =

let variadic_ode_adjoint_fn = "ode_adjoint_tol_ctl"

let nonadjoint_variadic_ode_fns =
let variadic_ode_nonadjoint_fns =
String.Set.of_list
[ "ode_bdf_tol"; "ode_rk45_tol"; "ode_adams_tol"; "ode_bdf"; "ode_rk45"
; "ode_adams"; "ode_ckrk"; "ode_ckrk_tol" ]

let ode_tolerances_suffix = "_tol"
let is_reduce_sum_fn f = Set.mem reduce_sum_functions f
let is_nonadjoint_variadic_ode_fn f = Set.mem nonadjoint_variadic_ode_fns f
let is_variadic_ode_nonadjoint_fn f = Set.mem variadic_ode_nonadjoint_fns f

let is_variadic_ode_fn f =
Set.mem nonadjoint_variadic_ode_fns f || f = variadic_ode_adjoint_fn
Set.mem variadic_ode_nonadjoint_fns f || f = variadic_ode_adjoint_fn

let is_nonadjoint_variadic_ode_tol_fn f =
is_nonadjoint_variadic_ode_fn f
let is_variadic_ode_nonadjoint_tol_fn f =
is_variadic_ode_nonadjoint_fn f
&& String.is_suffix f ~suffix:ode_tolerances_suffix

let distributions =
Expand Down

0 comments on commit cee75fa

Please sign in to comment.