Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ode_adjoint_tol_ctl #900

Merged
merged 20 commits into from
May 19, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 6 additions & 9 deletions src/frontend/Semantic_check.ml
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,9 @@ let semantic_check_reduce_sum ~is_cond_dist ~loc id es =

let semantic_check_variadic_ode ~is_cond_dist ~loc id es =
let optional_tol_mandatory_args =
if Stan_math_signatures.is_variadic_ode_tol_fn id.name then
if Stan_math_signatures.variadic_ode_adjoint_fn = id.name then
Stan_math_signatures.variadic_ode_adjoint_ctl_tol_arg_types
else if Stan_math_signatures.is_variadic_ode_nonadjoint_tol_fn id.name then
Stan_math_signatures.variadic_ode_tol_arg_types
else []
in
Expand All @@ -378,8 +380,7 @@ let semantic_check_variadic_ode ~is_cond_dist ~loc id es =
@ optional_tol_mandatory_args
in
let generic_variadic_ode_semantic_error =
Semantic_error.illtyped_variadic_ode loc id.name
(List.map ~f:type_of_expr_typed es)
Semantic_error.illtyped_variadic_ode loc id.name (List.map ~f:arg_type es)
[]
|> Validate.error
in
Expand All @@ -395,11 +396,8 @@ let semantic_check_variadic_ode ~is_cond_dist ~loc id es =
{type_= UnsizedType.UFun (fun_args, ReturnType return_type, FnPlain); _}; _
}
:: args ->
let num_of_mandatory_args =
if Stan_math_signatures.is_variadic_ode_tol_fn id.name then 6 else 3
in
let mandatory_args, variadic_args =
List.split_n args num_of_mandatory_args
List.split_n args (List.length mandatory_arg_types)
in
let mandatory_fun_args, variadic_fun_args = List.split_n fun_args 2 in
if
Expand All @@ -417,8 +415,7 @@ let semantic_check_variadic_ode ~is_cond_dist ~loc id es =
|> Validate.ok
else
Semantic_error.illtyped_variadic_ode loc id.name
(List.map ~f:type_of_expr_typed es)
fun_args
(List.map ~f:arg_type es) fun_args
|> Validate.error
else generic_variadic_ode_semantic_error
| _ -> generic_variadic_ode_semantic_error
Expand Down
58 changes: 32 additions & 26 deletions src/middle/Semantic_error.ml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ module TypeError = struct
| IllTypedReduceSumGeneric of string * UnsizedType.t list
| IllTypedVariadicODE of
string
* UnsizedType.t list
* (UnsizedType.autodifftype * UnsizedType.t) list
* (UnsizedType.autodifftype * UnsizedType.t) list
| ReturningFnExpectedNonReturningFound of string
| ReturningFnExpectedNonFnFound of string
Expand Down Expand Up @@ -154,19 +154,21 @@ module TypeError = struct
Fmt.(list UnsizedType.pp ~sep:comma)
arg_tys
| IllTypedVariadicODE (name, arg_tys, args) ->
let types x = List.map ~f:snd x in
let optional_tol_args =
if Stan_math_signatures.is_variadic_ode_tol_fn name then
types Stan_math_signatures.variadic_ode_tol_arg_types
if Stan_math_signatures.variadic_ode_adjoint_fn = name then
Stan_math_signatures.variadic_ode_adjoint_ctl_tol_arg_types
else if Stan_math_signatures.is_variadic_ode_nonadjoint_tol_fn name
then Stan_math_signatures.variadic_ode_tol_arg_types
else []
in
let generate_ode_sig =
[ UnsizedType.UFun
( Stan_math_signatures.variadic_ode_mandatory_fun_args @ args
, ReturnType Stan_math_signatures.variadic_ode_fun_return_type
, FnPlain ) ]
@ types Stan_math_signatures.variadic_ode_mandatory_arg_types
@ optional_tol_args @ types args
[ ( UnsizedType.AutoDiffable
, UnsizedType.UFun
( Stan_math_signatures.variadic_ode_mandatory_fun_args @ args
, ReturnType Stan_math_signatures.variadic_ode_fun_return_type
, FnPlain ) ) ]
@ Stan_math_signatures.variadic_ode_mandatory_arg_types
@ optional_tol_args @ args
in
(* This function is used to generate the generic signature for variadic ODEs,
i.e. with ... representing the variadic parts of the signature.
Expand All @@ -177,20 +179,24 @@ module TypeError = struct
(with explicit types for variadic args). *)
let variadic_ode_generic_signature =
let optional_tol_args =
if Stan_math_signatures.is_variadic_ode_tol_fn name then
types Stan_math_signatures.variadic_ode_tol_arg_types
if Stan_math_signatures.variadic_ode_adjoint_fn = name then
Stan_math_signatures.variadic_ode_adjoint_ctl_tol_arg_types
else if Stan_math_signatures.is_variadic_ode_nonadjoint_tol_fn name
then Stan_math_signatures.variadic_ode_tol_arg_types
else []
in
match
( types Stan_math_signatures.variadic_ode_mandatory_arg_types
, types Stan_math_signatures.variadic_ode_mandatory_fun_args )
( Stan_math_signatures.variadic_ode_mandatory_arg_types
, Stan_math_signatures.variadic_ode_mandatory_fun_args )
with
| arg0 :: arg1 :: arg2 :: _, fun_arg0 :: fun_arg1 :: _ ->
Fmt.strf "(%a, %a, ...) => %a, %a, %a, %a, %a ...\n"
UnsizedType.pp fun_arg0 UnsizedType.pp fun_arg1 UnsizedType.pp
Fmt.strf "@[<hov 1>(%a, %a, ...) => %a, %a, %a, %a, %a ...@]"
UnsizedType.pp_fun_arg fun_arg0 UnsizedType.pp_fun_arg fun_arg1
UnsizedType.pp
Stan_math_signatures.variadic_ode_fun_return_type
UnsizedType.pp arg0 UnsizedType.pp arg1 UnsizedType.pp arg2
Fmt.(list UnsizedType.pp ~sep:comma)
UnsizedType.pp_fun_arg arg0 UnsizedType.pp_fun_arg arg1
UnsizedType.pp_fun_arg arg2
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change is so the error also prints data for all variadic ODE types.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMO the signatures are pretty much unreadable either way but I suppose the data markers help if that was what caused the error.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I agree that the prints of the two large signatures is not great... But that is a different issue I guess as its not limited to just the variadic ODEs.

Fmt.(list UnsizedType.pp_fun_arg ~sep:comma)
optional_tol_args
| _ ->
raise_s
Expand All @@ -200,26 +206,26 @@ module TypeError = struct
supplied to the variadic ODE function has exactly two \
mandatory arguments."]
in
if List.length args = 0 then
if List.length args <> 0 then
Fmt.pf ppf
"Ill-typed arguments supplied to function '%s'. Expected \
arguments:@[<h>%a@]\n\
@[<h>Instead supplied arguments of incompatible type:\n\
%a@]"
Instead supplied arguments of incompatible type:\n\
@[<h>%a@]"
name
Fmt.(list UnsizedType.pp ~sep:comma)
Fmt.(list UnsizedType.pp_fun_arg ~sep:comma)
generate_ode_sig
Fmt.(list UnsizedType.pp ~sep:comma)
Fmt.(list UnsizedType.pp_fun_arg ~sep:comma)
arg_tys
else
Fmt.pf ppf
"Ill-typed arguments supplied to function '%s'. @[<h>Available \
signatures:\n\
%s.@]\n\
@[<h>%s@]\n\
@[<h>Instead supplied arguments of incompatible type:\n\
%a.@]"
%a@]"
name variadic_ode_generic_signature
Fmt.(list UnsizedType.pp ~sep:comma)
Fmt.(list UnsizedType.pp_fun_arg ~sep:comma)
arg_tys
| NotIndexable (ut, nidcs) ->
Fmt.pf ppf
Expand Down
2 changes: 1 addition & 1 deletion src/middle/Semantic_error.mli
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ val illtyped_reduce_sum_generic :
val illtyped_variadic_ode :
Location_span.t
-> string
-> UnsizedType.t list
-> (UnsizedType.autodifftype * UnsizedType.t) list
-> (UnsizedType.autodifftype * UnsizedType.t) list
-> t

Expand Down
34 changes: 28 additions & 6 deletions src/middle/Stan_math_signatures.ml
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,25 @@ let reduce_sum_slice_types =
List.concat (List.map ~f:base_slice_type reduce_sum_allowed_dimensionalities)

(* Variadic ODE *)
let variadic_ode_adjoint_ctl_tol_arg_types =
[ (UnsizedType.DataOnly, UnsizedType.UReal)
(* real relative_tolerance_forward *)
; (DataOnly, UVector) (* vector absolute_tolerance_forward *)
; (DataOnly, UReal) (* real relative_tolerance_backward *)
; (DataOnly, UVector) (* real absolute_tolerance_backward *)
; (DataOnly, UReal) (* real relative_tolerance_quadrature *)
; (DataOnly, UReal) (* real absolute_tolerance_quadrature *)
; (DataOnly, UInt) (* int max_num_steps *)
; (DataOnly, UInt) (* int num_steps_between_checkpoints *)
; (DataOnly, UInt) (* int interpolation_polynomial *)
; (DataOnly, UInt) (* int solver_forward *)
; (DataOnly, UInt)
(* int solver_backward *)
]

let variadic_ode_tol_arg_types =
[ (UnsizedType.AutoDiffable, UnsizedType.UReal)
; (AutoDiffable, UReal); (DataOnly, UInt) ]
[ (UnsizedType.DataOnly, UnsizedType.UReal)
; (DataOnly, UReal); (DataOnly, UInt) ]
Comment on lines -99 to +116
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are for nonadjoint ODEs. Why did they change?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a bug currently. I should have mention it in the PR comment, sorry.

The tolerances in the nonadjoint should not allow autodiffable arguments as that causes C++ not to compile. This was reported by @bbbales2 in the Math PR but that PR has 500+ comments and I cant find the exact comment right now.

This would now error in stanc3 which is better I guess.


let variadic_ode_mandatory_arg_types =
[ (UnsizedType.AutoDiffable, UnsizedType.UVector)
Expand Down Expand Up @@ -160,17 +176,23 @@ let full_lpmf = [Lpmf; Rng; Ccdf; Cdf]
let reduce_sum_functions =
String.Set.of_list ["reduce_sum"; "reduce_sum_static"]

let variadic_ode_functions =
let variadic_ode_adjoint_fn = "ode_adjoint_tol_ctl"

let variadic_ode_nonadjoint_fns =
String.Set.of_list
[ "ode_bdf_tol"; "ode_rk45_tol"; "ode_adams_tol"; "ode_bdf"; "ode_rk45"
; "ode_adams"; "ode_ckrk"; "ode_ckrk_tol" ]

let ode_tolerances_suffix = "_tol"
let is_reduce_sum_fn f = Set.mem reduce_sum_functions f
let is_variadic_ode_fn f = Set.mem variadic_ode_functions f
let is_variadic_ode_nonadjoint_fn f = Set.mem variadic_ode_nonadjoint_fns f

let is_variadic_ode_fn f =
Set.mem variadic_ode_nonadjoint_fns f || f = variadic_ode_adjoint_fn

let is_variadic_ode_tol_fn f =
is_variadic_ode_fn f && String.is_suffix f ~suffix:ode_tolerances_suffix
let is_variadic_ode_nonadjoint_tol_fn f =
is_variadic_ode_nonadjoint_fn f
&& String.is_suffix f ~suffix:ode_tolerances_suffix

let distributions =
[ (full_lpmf, "beta_binomial", [DVInt; DVInt; DVReal; DVReal])
Expand Down
28 changes: 26 additions & 2 deletions src/stan_math_backend/Expression_gen.ml
Original file line number Diff line number Diff line change
Expand Up @@ -335,13 +335,37 @@ and gen_fun_app suffix ppf fname es =
| true, x, f :: y0 :: t0 :: ts :: rel_tol :: abs_tol :: max_steps :: tl
when Stan_math_signatures.is_variadic_ode_fn x
&& String.is_suffix fname
~suffix:Stan_math_signatures.ode_tolerances_suffix ->
~suffix:Stan_math_signatures.ode_tolerances_suffix
&& not (Stan_math_signatures.variadic_ode_adjoint_fn = x) ->
( fname
, f :: y0 :: t0 :: ts :: rel_tol :: abs_tol :: max_steps :: msgs
:: tl )
| true, x, f :: y0 :: t0 :: ts :: tl
when Stan_math_signatures.is_variadic_ode_fn x ->
when Stan_math_signatures.is_variadic_ode_fn x
&& not (Stan_math_signatures.variadic_ode_adjoint_fn = x) ->
(fname, f :: y0 :: t0 :: ts :: msgs :: tl)
| ( true
, x
, f
:: y0
:: t0
:: ts
:: rel_tol
:: abs_tol
:: rel_tol_b
:: abs_tol_b
:: rel_tol_q
:: abs_tol_q
:: max_num_steps
:: num_checkpoints
:: interpolation_polynomial
:: solver_f :: solver_b :: tl )
when Stan_math_signatures.variadic_ode_adjoint_fn = x ->
( fname
, f :: y0 :: t0 :: ts :: rel_tol :: abs_tol :: rel_tol_b :: abs_tol_b
:: rel_tol_q :: abs_tol_q :: max_num_steps :: num_checkpoints
:: interpolation_polynomial :: solver_f :: solver_b :: msgs :: tl
)
| ( true
, "map_rect"
, {pattern= FunApp ((UserDefined (f, _) | StanLib (f, _)), _); _} :: tl
Expand Down
44 changes: 44 additions & 0 deletions test/integration/bad/ode/bad_var_tol_1.stan
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
functions {
vector f_0_arg(real t, vector z) {
return z;
}
}

data {
int N;
int M;
int i;
}

transformed data {
real rel_tol_f;
vector[N] abs_tol_f;
real rel_tol_b;
vector[N] abs_tol_b;
real abs_tol_q;
int max_num_steps;
int num_checkpoints;
int interpolation_polynomial;
int solver_f;
int solver_b;
}

parameters {
real y;

vector[N] y0;
real t0;
array[N] real times;
real rel_tol_q;
}

transformed parameters {
array[M] vector[N] z;

z = ode_adjoint_tol_ctl(f_0_arg, y0, t0, times, rel_tol_f, abs_tol_f, rel_tol_b, abs_tol_b, rel_tol_q, abs_tol_q,
max_num_steps, num_checkpoints, interpolation_polynomial, solver_f, solver_b);
}

model {
y ~ normal(0, 1);
}
44 changes: 44 additions & 0 deletions test/integration/bad/ode/bad_var_tol_2.stan
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
functions {
vector f_0_arg(real t, vector z) {
return z;
}
}

data {
int N;
int M;
int i;
}

transformed data {
real rel_tol_f;
vector[N] abs_tol_f;
vector[N] abs_tol_b;
real abs_tol_q;
real rel_tol_q;
int max_num_steps;
int num_checkpoints;
int interpolation_polynomial;
int solver_f;
int solver_b;
}

parameters {
real y;

vector[N] y0;
real t0;
array[N] real times;
real rel_tol_b;
}

transformed parameters {
array[M] vector[N] z;

z = ode_adjoint_tol_ctl(f_0_arg, y0, t0, times, rel_tol_f, abs_tol_f, rel_tol_b, abs_tol_b, rel_tol_q, abs_tol_q,
max_num_steps, num_checkpoints, interpolation_polynomial, solver_f, solver_b);
}

model {
y ~ normal(0, 1);
}
44 changes: 44 additions & 0 deletions test/integration/bad/ode/bad_var_tol_3.stan
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
functions {
vector f_0_arg(real t, vector z) {
return z;
}
}

data {
int N;
int M;
int i;
}

transformed data {
real rel_tol_f;
vector[N] abs_tol_b;
real rel_tol_b;
real abs_tol_q;
real rel_tol_q;
int max_num_steps;
int num_checkpoints;
int interpolation_polynomial;
int solver_f;
int solver_b;
}

parameters {
real y;

vector[N] y0;
real t0;
array[N] real times;
vector[N] abs_tol_f;
}

transformed parameters {
array[M] vector[N] z;

z = ode_adjoint_tol_ctl(f_0_arg, y0, t0, times, rel_tol_f, abs_tol_f, rel_tol_b, abs_tol_b, rel_tol_q, abs_tol_q,
max_num_steps, num_checkpoints, interpolation_polynomial, solver_f, solver_b);
}

model {
y ~ normal(0, 1);
}
Loading