-
-
Notifications
You must be signed in to change notification settings - Fork 48
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add ode_adjoint_tol_ctl
#900
Changes from all commits
52d6082
b83b39f
9257708
c61c2b5
1f2d3c7
bfccd70
c32cf54
ba136fa
1b6ad32
ea3a31b
89f8026
b3bd7b3
627471a
eae675b
426b7b5
ef40d65
25d939e
cee75fa
bb74307
1d0bd17
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -95,9 +95,25 @@ let reduce_sum_slice_types = | |
List.concat (List.map ~f:base_slice_type reduce_sum_allowed_dimensionalities) | ||
|
||
(* Variadic ODE *) | ||
let variadic_ode_adjoint_ctl_tol_arg_types = | ||
[ (UnsizedType.DataOnly, UnsizedType.UReal) | ||
(* real relative_tolerance_forward *) | ||
; (DataOnly, UVector) (* vector absolute_tolerance_forward *) | ||
; (DataOnly, UReal) (* real relative_tolerance_backward *) | ||
; (DataOnly, UVector) (* real absolute_tolerance_backward *) | ||
; (DataOnly, UReal) (* real relative_tolerance_quadrature *) | ||
; (DataOnly, UReal) (* real absolute_tolerance_quadrature *) | ||
; (DataOnly, UInt) (* int max_num_steps *) | ||
; (DataOnly, UInt) (* int num_steps_between_checkpoints *) | ||
; (DataOnly, UInt) (* int interpolation_polynomial *) | ||
; (DataOnly, UInt) (* int solver_forward *) | ||
; (DataOnly, UInt) | ||
(* int solver_backward *) | ||
] | ||
|
||
let variadic_ode_tol_arg_types = | ||
[ (UnsizedType.AutoDiffable, UnsizedType.UReal) | ||
; (AutoDiffable, UReal); (DataOnly, UInt) ] | ||
[ (UnsizedType.DataOnly, UnsizedType.UReal) | ||
; (DataOnly, UReal); (DataOnly, UInt) ] | ||
Comment on lines
-99
to
+116
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. These are for nonadjoint ODEs. Why did they change? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is a bug currently. I should have mention it in the PR comment, sorry. The tolerances in the nonadjoint should not allow autodiffable arguments as that causes C++ not to compile. This was reported by @bbbales2 in the Math PR but that PR has 500+ comments and I cant find the exact comment right now. This would now error in stanc3 which is better I guess. |
||
|
||
let variadic_ode_mandatory_arg_types = | ||
[ (UnsizedType.AutoDiffable, UnsizedType.UVector) | ||
|
@@ -160,17 +176,23 @@ let full_lpmf = [Lpmf; Rng; Ccdf; Cdf] | |
let reduce_sum_functions = | ||
String.Set.of_list ["reduce_sum"; "reduce_sum_static"] | ||
|
||
let variadic_ode_functions = | ||
let variadic_ode_adjoint_fn = "ode_adjoint_tol_ctl" | ||
|
||
let variadic_ode_nonadjoint_fns = | ||
String.Set.of_list | ||
[ "ode_bdf_tol"; "ode_rk45_tol"; "ode_adams_tol"; "ode_bdf"; "ode_rk45" | ||
; "ode_adams"; "ode_ckrk"; "ode_ckrk_tol" ] | ||
|
||
let ode_tolerances_suffix = "_tol" | ||
let is_reduce_sum_fn f = Set.mem reduce_sum_functions f | ||
let is_variadic_ode_fn f = Set.mem variadic_ode_functions f | ||
let is_variadic_ode_nonadjoint_fn f = Set.mem variadic_ode_nonadjoint_fns f | ||
|
||
let is_variadic_ode_fn f = | ||
Set.mem variadic_ode_nonadjoint_fns f || f = variadic_ode_adjoint_fn | ||
|
||
let is_variadic_ode_tol_fn f = | ||
is_variadic_ode_fn f && String.is_suffix f ~suffix:ode_tolerances_suffix | ||
let is_variadic_ode_nonadjoint_tol_fn f = | ||
is_variadic_ode_nonadjoint_fn f | ||
&& String.is_suffix f ~suffix:ode_tolerances_suffix | ||
|
||
let distributions = | ||
[ (full_lpmf, "beta_binomial", [DVInt; DVInt; DVReal; DVReal]) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
functions { | ||
vector f_0_arg(real t, vector z) { | ||
return z; | ||
} | ||
} | ||
|
||
data { | ||
int N; | ||
int M; | ||
int i; | ||
} | ||
|
||
transformed data { | ||
real rel_tol_f; | ||
vector[N] abs_tol_f; | ||
real rel_tol_b; | ||
vector[N] abs_tol_b; | ||
real abs_tol_q; | ||
int max_num_steps; | ||
int num_checkpoints; | ||
int interpolation_polynomial; | ||
int solver_f; | ||
int solver_b; | ||
} | ||
|
||
parameters { | ||
real y; | ||
|
||
vector[N] y0; | ||
real t0; | ||
array[N] real times; | ||
real rel_tol_q; | ||
} | ||
|
||
transformed parameters { | ||
array[M] vector[N] z; | ||
|
||
z = ode_adjoint_tol_ctl(f_0_arg, y0, t0, times, rel_tol_f, abs_tol_f, rel_tol_b, abs_tol_b, rel_tol_q, abs_tol_q, | ||
max_num_steps, num_checkpoints, interpolation_polynomial, solver_f, solver_b); | ||
} | ||
|
||
model { | ||
y ~ normal(0, 1); | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
functions { | ||
vector f_0_arg(real t, vector z) { | ||
return z; | ||
} | ||
} | ||
|
||
data { | ||
int N; | ||
int M; | ||
int i; | ||
} | ||
|
||
transformed data { | ||
real rel_tol_f; | ||
vector[N] abs_tol_f; | ||
vector[N] abs_tol_b; | ||
real abs_tol_q; | ||
real rel_tol_q; | ||
int max_num_steps; | ||
int num_checkpoints; | ||
int interpolation_polynomial; | ||
int solver_f; | ||
int solver_b; | ||
} | ||
|
||
parameters { | ||
real y; | ||
|
||
vector[N] y0; | ||
real t0; | ||
array[N] real times; | ||
real rel_tol_b; | ||
} | ||
|
||
transformed parameters { | ||
array[M] vector[N] z; | ||
|
||
z = ode_adjoint_tol_ctl(f_0_arg, y0, t0, times, rel_tol_f, abs_tol_f, rel_tol_b, abs_tol_b, rel_tol_q, abs_tol_q, | ||
max_num_steps, num_checkpoints, interpolation_polynomial, solver_f, solver_b); | ||
} | ||
|
||
model { | ||
y ~ normal(0, 1); | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
functions { | ||
vector f_0_arg(real t, vector z) { | ||
return z; | ||
} | ||
} | ||
|
||
data { | ||
int N; | ||
int M; | ||
int i; | ||
} | ||
|
||
transformed data { | ||
real rel_tol_f; | ||
vector[N] abs_tol_b; | ||
real rel_tol_b; | ||
real abs_tol_q; | ||
real rel_tol_q; | ||
int max_num_steps; | ||
int num_checkpoints; | ||
int interpolation_polynomial; | ||
int solver_f; | ||
int solver_b; | ||
} | ||
|
||
parameters { | ||
real y; | ||
|
||
vector[N] y0; | ||
real t0; | ||
array[N] real times; | ||
vector[N] abs_tol_f; | ||
} | ||
|
||
transformed parameters { | ||
array[M] vector[N] z; | ||
|
||
z = ode_adjoint_tol_ctl(f_0_arg, y0, t0, times, rel_tol_f, abs_tol_f, rel_tol_b, abs_tol_b, rel_tol_q, abs_tol_q, | ||
max_num_steps, num_checkpoints, interpolation_polynomial, solver_f, solver_b); | ||
} | ||
|
||
model { | ||
y ~ normal(0, 1); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This change is so the error also prints
data
for all variadic ODE types.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
IMO the signatures are pretty much unreadable either way but I suppose the data markers help if that was what caused the error.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, I agree that the prints of the two large signatures is not great... But that is a different issue I guess as its not limited to just the variadic ODEs.