- Feature Name: ode-adjoint
- Start Date: 2021-02-02
- RFC PR: (leave this empty)
- Stan Issue: (leave this empty)
Ordinary differential equations (ODEs) are currently expensive to
evaluate in Stan for larger systems. This is due to the multiplicative
scaling of the required computing resources with the ODE system size
The computational cost of the adjoint method scales much more favorably in comparison which is
The advantage of the adjoint method shows in particular whenever the number of parameters is relatively large in comparison to the number of states. Most importantly, the computational cost grows only linearly in the number of states and parameters (while the forward method grows multiplicatively). Thus, this method can scale to much larger problems without a multiplicatively increasing computational cost. However, the adjoint method is more involved, resulting in a bigger overhead than the forward mode method. As a result, the forward method can have advantages for smaller ODE systems and as such both methods remain valuable within Stan.
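To make the scaling argument concrete, the following sketch counts how many scalar equations each method integrates for an ODE with `n_states` states and `n_params` parameters. It assumes that only the parameters (not the initial state) require sensitivities and that a single gradient is requested. The function names are hypothetical, and the counts ignore overheads such as checkpointing, so they indicate a trend rather than predict runtime.

```python
def forward_system_size(n_states: int, n_params: int) -> int:
    """Forward sensitivities: N states plus N * M sensitivity equations."""
    return n_states * (1 + n_params)


def adjoint_system_size(n_states: int, n_params: int) -> int:
    """Adjoint method: N forward states, N backward adjoint states
    and one quadrature per parameter (M)."""
    return 2 * n_states + n_params


if __name__ == "__main__":
    for n, m in [(2, 3), (10, 20), (100, 200)]:
        print(f"N={n:4d} M={m:4d} "
              f"forward={forward_system_size(n, m):6d} "
              f"adjoint={adjoint_system_size(n, m):6d}")
```

For very small systems the two counts are close, which is consistent with the forward method remaining competitive there once the extra overhead of the adjoint method is taken into account.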
Stan is currently limited in practice to solving ODE problems that are small in the number of states and parameters. If either of them gets large, the computational cost explodes quickly. With the adjoint ODE solving method Stan will be able to scale to much larger problems involving more states and more parameters. Examples are large pharmacokinetic / pharmacodynamic systems, physiologically based pharmacokinetic models or bigger susceptible, infected, and recovered (SIR) models. With the adjoint ODE method the computational cost grows only linearly with states and parameters, and therefore modelers can more freely increase the complexity of the ODE model without substantially increasing inference times.
The new solver `ode_adjoint` will follow in its design the variadic
ODE functions. As the new solver is numerically more involved (see
reference level) there are more tuning parameters needed for
the solver. The actual implementation will be provided by the CVODES
package which we already include in Stan.
From a more internal Stan Math perspective, the key difference to the
existing forward mode solvers will be that the gradients are no longer
pre-computed during the forward sweep of reverse mode autodiff
(AD). Instead, during the forward sweep of reverse mode autodiff, the
ODE is solved only for the N states (without any sensitivities). When
reverse mode AD then performs the backward sweep, the autodiff
adjoints of the parameters of the ODE (input operands to the ODE
function call) are directly computed given the parent autodiff
adjoints as defined by the autodiff call graph. This calculation
involves a backward solve of an adjoint ODE system with
The numerical complexity is higher for the ODE adjoint method in comparison to the forward method. While most of the complexity is handled by CVODES, the numerical procedure seems to require more knowledge about the tuning parameters. At least at the moment it does not appear possible to offer an easy-to-use interface to this functionality that hides most of the tuning parameters. We need to first collect some experience before a simpler version can be made available (if that is feasible at all). The tuning parameters exposed as proposed are:
- absolute & relative tolerance forward solve; absolute tolerance should be a vector of length $N$, the number of states
- absolute & relative tolerance backward solve (for consistency the absolute tolerances are given as vector)
- absolute & relative tolerance quadrature problem
- maximal number of steps between time-points; same for forward and backward integration
- forward solver: 1 = Adams, 2 = BDF
- backward solver: 1 = Adams, 2 = BDF
- checkpointing: number of steps between checkpoints
- checkpointing: 1 = Hermite, 2 = polynomial approximation
During an experimental phase of this feature we can hopefully learn
which of these are relevant. In order to make it easy for users to
switch their existing ODE problems to the new solver, an
`ode_adjoint_tol` function will be provided during the experimental
phase. This function call follows the same conventions as the existing
ODE integrators. The tuning parameters set as default will be
described below. The simplified call will be integrated into the main
Stan Math library if users find this function helpful and a reasonable
default for the tuning parameters can be found during the experimental
phase.
The adjoint method for ODEs is relatively involved
mathematically. The main challenge comes from a mix of different
notation and conventions used in different fields. Here we relate
commonly used notation within Stan Math to the user guide of
CVODES, the underlying library we employ to handle the numerical
solution. The
goal of reverse mode autodiff is to calculate the derivative of some
function
\begin{align}
\dot{y} &= f(y,t,p) (\#eq:odeDef) \\
y(t_0) &= y^{(0)}. (\#eq:odeInitial)
\end{align}
The ODE has
and we must calculate the contribution to the autodiff adjoint of the parameters
which involves for each parameter
In the notation of the CVODES user manual, the function
\begin{align}
\dot{\lambda} &= - \left(\frac{\partial f}{\partial y}\right)^* \lambda - \left(\frac{\partial g}{\partial y}\right)^* (\#eq:lambdaOde) \\
\lambda(T) &= 0. (\#eq:lambdaInitial)
\end{align}
This ODE is referred to as the backward problem, since we fix the
solution
\begin{equation}
\left.\frac{d g}{d p}\right|_{t=T} = \mu^*(t_0) \, s(t_0) + \left.\frac{\partial g}{\partial p}\right|_{t=T} + \int_{t_0}^{T} \mu^* \, \frac{\partial f}{\partial p} \, dt. (\#eq:derivg)
\end{equation}
Here
\begin{align}
\dot{\mu} &= - \left(\frac{\partial f}{\partial y}\right)^* \, \mu (\#eq:muODE) \\
\mu(T) &= \left(\frac{\partial g}{\partial y}\right)^*_{t=T}. (\#eq:muInitial)
\end{align}
If we now recall that
$$\left(\frac{\partial g}{\partial y_n}\right)^*_{t=T} = a_{l,y_n}^*, $$
which is the input we get during reverse mode autodiff (the conjugation does not apply for Stan math using reals only).
It is important to note that the backward ODE problem requires as input the autodiff adjoint $a_{l,y_n}^*$, such that the backward problem must be solved during the reverse pass of reverse mode autodiff. Thus, CVODES is used during the forward pass to solve (forward) the ODE in @ref(eq:odeDef). Once the backward pass is initiated, the autodiff adjoints $a_{l,y_n}^*$ are used as the terminal value in order to solve the backward problem of @ref(eq:muODE). Whenever any parameter sensitivities are requested, we must solve, along with the backward problem, an additional one-dimensional quadrature problem per parameter of the ODE, whose integrand appears in equation @ref(eq:derivg).
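The forward pass, backward problem and quadrature can be traced on a scalar toy problem. The sketch below is an illustration only and is not how Stan Math or CVODES implement the method: it assumes the decay ODE $\dot{y} = -p\,y$ with a fixed initial state (so that $s(t_0) = 0$), a single output time $T$, and a hand-written fixed-step Runge-Kutta integrator in place of CVODES. The function name `adjoint_gradient` and the argument `seed` (standing in for the incoming autodiff adjoint $a_{l,y_n}$) are hypothetical. For this problem $\partial f / \partial y = -p$ and $\partial f / \partial p = -y$, so the backward problem reads $\dot{\mu} = p\,\mu$ and the quadrature integrand is $-\mu\,y$.

```python
import math


def adjoint_gradient(p, y0, t_end, seed=1.0, n_steps=200):
    """Return (y(T), seed * dy(T)/dp) for dy/dt = -p * y, y(0) = y0."""
    h = t_end / n_steps
    half = h / 2.0

    def f(y):
        return -p * y

    # Forward pass: solve for the state only (no sensitivities) and store
    # it on a grid of spacing h/2 so the backward pass can look it up.
    traj = [y0]
    for _ in range(2 * n_steps):
        y = traj[-1]
        k1 = f(y)
        k2 = f(y + half / 2.0 * k1)
        k3 = f(y + half / 2.0 * k2)
        k4 = f(y + half * k3)
        traj.append(y + half / 6.0 * (k1 + 2.0 * k2 + 2.0 * k3 + k4))

    # Backward pass from T to t0 with terminal values mu(T) = seed, q(T) = 0.
    #   d(mu)/dt = -(df/dy) * mu = p * mu
    #   dq/dt    = -mu * (df/dp) = mu * y,  so q(t0) = int_{t0}^{T} mu * df/dp dt
    mu, q = seed, 0.0
    for i in range(n_steps, 0, -1):
        y_hi, y_mid, y_lo = traj[2 * i], traj[2 * i - 1], traj[2 * i - 2]
        k1m, k1q = p * mu, mu * y_hi
        m2 = mu - half * k1m
        k2m, k2q = p * m2, m2 * y_mid
        m3 = mu - half * k2m
        k3m, k3q = p * m3, m3 * y_mid
        m4 = mu - h * k3m
        k4m, k4q = p * m4, m4 * y_lo
        mu -= h / 6.0 * (k1m + 2.0 * k2m + 2.0 * k3m + k4m)
        q -= h / 6.0 * (k1q + 2.0 * k2q + 2.0 * k3q + k4q)

    return traj[-1], q


if __name__ == "__main__":
    y_end, grad = adjoint_gradient(p=0.5, y0=2.0, t_end=3.0)
    exact = -3.0 * 2.0 * math.exp(-0.5 * 3.0)  # d/dp [y0 * exp(-p * T)]
    print(y_end, grad, exact)
```

The analytic solution $y(T) = y^{(0)} e^{-pT}$ gives $\partial y(T) / \partial p = -T\,y^{(0)} e^{-pT}$, which the backward quadrature reproduces up to discretization error. Note that the backward loop can only start once `seed` is known, which mirrors why the backward problem has to be solved during the reverse pass.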
So far we have only considered the case of a single time-point
In total we need to solve 3 integration problems: a forward ODE problem, a backward ODE problem and a backward quadrature problem. The forward ODE and the backward ODE problem can each be solved with either the non-stiff Adams or the stiff BDF solver of CVODES (the choice is independent for each problem). Each of the 3 integration problems has its own relative and absolute tolerance targets.
The proposed function should have the following signature:

    vector[] ode_adjoint_tol_ctl(F f,
        vector y0,
        real t0, real[] times,
        real relative_tolerance_forward, vector absolute_tolerance_forward,
        real relative_tolerance_backward, vector absolute_tolerance_backward,
        real relative_tolerance_quadrature, real absolute_tolerance_quadrature,
        int max_num_steps,
        int num_steps_between_checkpoints,
        int interpolation_polynomial,
        int solver_forward, int solver_backward,
        T1 arg1, T2 arg2, ...)

The arguments are:
- `f` - User-defined right hand side of the ODE (`dy/dt = f`)
- `y0` - Initial state of the ode solve (`y0 = y(t0)`)
- `t0` - Initial time of the ode solve
- `times` - Sorted array of times to which the ode will be solved (each element must be greater than `t0`)
- `relative_tolerance_forward` - Relative tolerance for forward solve (data)
- `absolute_tolerance_forward` - Absolute tolerance vector for each state for forward solve (data)
- `relative_tolerance_backward` - Relative tolerance for backward solve (data)
- `absolute_tolerance_backward` - Absolute tolerance vector for each state for backward solve (data)
- `relative_tolerance_quadrature` - Relative tolerance for backward quadrature (data)
- `absolute_tolerance_quadrature` - Absolute tolerance for backward quadrature (data)
- `max_num_steps` - Maximum number of time-steps to take in integrating the ODE solution between output time points for forward and backward solve (data)
- `num_steps_between_checkpoints` - Number of steps between checkpointing forward solution (data)
- `interpolation_polynomial` - Can be 1 for hermite or 2 for polynomial interpolation method of CVODES
- `solver_forward` - Solver used for forward ODE problem: 1=Adams, 2=bdf
- `solver_backward` - Solver used for backward ODE problem: 1=Adams, 2=bdf
- `arg1, arg2, ...` - Arguments passed unmodified to the ODE right hand side. The types `T1, T2, ...` can be any type, but they must match the types of the matching arguments of `f`.
In addition, a simpler version of the function that mimics the existing ODE interface for tolerances should be made available once more experience has been gained with the new adjoint ODE method. The simpler interface is intended to let users quickly try out the new adjoint solver, but at the current stage it is premature to provide it. As it is currently not clear how the extra control parameters should be set, this function signature will not be made available in the Stan language in the first version of the adjoint ODE solver. Nonetheless, once sufficient experience has been gained with the adjoint ODE method the following signature should be made available:

    vector[] ode_adjoint_tol(F f,
        vector y0,
        real t0, real[] times,
        real relative_tolerance, real absolute_tolerance,
        int max_num_steps,
        T1 arg1, T2 arg2, ...)

The arguments are:
- `f` - User-defined right hand side of the ODE (`dy/dt = f`)
- `y0` - Initial state of the ode solve (`y0 = y(t0)`)
- `t0` - Initial time of the ode solve
- `times` - Sorted array of times to which the ode will be solved (each element must be greater than `t0`)
- `relative_tolerance` - Relative tolerance for all solves (data)
- `absolute_tolerance` - Absolute tolerance for all solves (data)
- `max_num_steps` - Maximum number of time-steps to take in integrating the ODE solution between output time points for forward and backward solve (data)
- `arg1, arg2, ...` - Arguments passed unmodified to the ODE right hand side. The types `T1, T2, ...` can be any type, but they must match the types of the matching arguments of `f`.
The best guesses based on preliminary experiments for the remaining tuning parameters are then set as
- same relative tolerance in all problems
- `absolute_tolerance_forward = absolute_tolerance / 10`
- `absolute_tolerance_backward = absolute_tolerance / 3`
- `absolute_tolerance_quadrature = absolute_tolerance`
- $250$ steps between checkpoints
- bdf solver for forward and backward solve
- hermite polynomial method for forward function approximation
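The mapping from the simplified arguments to the full set of control parameters can be written down as a small sketch. The class `AdjointControl` and the function `default_control` are hypothetical names used only for illustration. The field names follow the proposed `ode_adjoint_tol_ctl` signature, and the values are the preliminary best guesses listed above, which may well change.

```python
from dataclasses import dataclass
from typing import List

ADAMS, BDF = 1, 2            # solver codes as listed in the proposal
HERMITE, POLYNOMIAL = 1, 2   # interpolation codes as listed in the proposal


@dataclass(frozen=True)
class AdjointControl:
    """Hypothetical container for the tuning parameters of the full signature."""
    relative_tolerance_forward: float
    absolute_tolerance_forward: List[float]
    relative_tolerance_backward: float
    absolute_tolerance_backward: List[float]
    relative_tolerance_quadrature: float
    absolute_tolerance_quadrature: float
    max_num_steps: int
    num_steps_between_checkpoints: int
    interpolation_polynomial: int
    solver_forward: int
    solver_backward: int


def default_control(relative_tolerance: float, absolute_tolerance: float,
                    max_num_steps: int, n_states: int) -> AdjointControl:
    """Expand the simplified arguments using the preliminary defaults."""
    return AdjointControl(
        # same relative tolerance in all three problems
        relative_tolerance_forward=relative_tolerance,
        relative_tolerance_backward=relative_tolerance,
        relative_tolerance_quadrature=relative_tolerance,
        # absolute tolerances are vectors of length N for the two ODE solves
        absolute_tolerance_forward=[absolute_tolerance / 10] * n_states,
        absolute_tolerance_backward=[absolute_tolerance / 3] * n_states,
        absolute_tolerance_quadrature=absolute_tolerance,
        max_num_steps=max_num_steps,
        num_steps_between_checkpoints=250,
        interpolation_polynomial=HERMITE,
        solver_forward=BDF,
        solver_backward=BDF,
    )


if __name__ == "__main__":
    print(default_control(1e-8, 1e-9, 10000, n_states=3))
```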
The above best guesses are only based on early experiments and will need to be refined. This design doc intentionally does not define when "sufficient" experience has been collected for the simplified function to be made available. However, the above defaults (or an updated set) should be included in the documentation of the adjoint ODE solver.
There is some work to be done. Other than that, there are to my knowledge no alternatives for getting large ODE systems to work in Stan. What we are missing out on for now is exploiting the sparsity structure of the ODE. This would allow for more efficient solvers and even larger systems, but at the moment it is not possible to determine structurally the inter-dependencies of inputs and outputs.
Of note is that the adjoint ODE method is by construction not efficient for calculating the gradient with respect to multiple outputs, as needed for a Jacobian. For each function output, the gradient for all parameters must be calculated by resetting the adjoints back to zero and then rerunning the backward sweep of AD (chain method). Hence, each function output gradient triggers a rerun of the backward integration problem, which is not needed for the forward ODE method.
The proposed function signature restricts the absolute tolerance for the backward quadrature problem to a scalar, as opposed to allowing a vector absolute tolerance which assigns to each parameter quadrature problem its own absolute tolerance. This simplification is made for usability reasons and based on the empirical observation that neither any of the CVODES example codes nor Julia's DifferentialEquations.jl package facilitates a vector absolute tolerance for the backward quadrature integration. A user may work around this limitation to some extent by scaling the states and the parameters, which will affect the order of magnitude of the gradients accordingly and therefore change the tolerance checks of CVODES whenever a gradient approaches zero. As the lack of a vector absolute tolerance for the parameter quadratures may potentially pose a limitation for future uses of the adjoint ODE solver, the implementation in stan-math should be written in a manner that allows this feature to be included at a later time (for example through an overload of the proposed signature).
There is no other analytical ODE solving technique available which scales in a comparable way. The better scalability of the adjoint method makes it possible, at fixed computational resources, to solve larger ODE systems with many parameters and/or states faster. Thus, a Stan modeler can increase the complexity of an ODE model under study without substantially increasing model runtimes.
The numerical complexities of the adjoint method are rather involved as 3 nested integrations are performed. This makes things somewhat fragile and less robust. What makes the backward integration particularly involved is that the solution of the forward problem must be available as a continuous function in memory, and hence an interpolation of the forward solve is required. This is provided by CVODES via a checkpointing procedure. In summary, we rely heavily on the CVODES infrastructure, as these building blocks are rather complex and heavy to craft on our own.
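To illustrate what interpolating the forward solve means, the sketch below evaluates a cubic Hermite interpolant between two stored solution points, using the state values and the right-hand-side evaluations (slopes) at both ends. This is only meant to convey the idea behind the Hermite option; it is not CVODES code, it leaves out checkpointing, and the function name `hermite_interpolate` is hypothetical.

```python
import math


def hermite_interpolate(t0, y0, dy0, t1, y1, dy1, t):
    """Cubic Hermite interpolant through (t0, y0) and (t1, y1)
    with slopes dy0 and dy1, evaluated at time t."""
    h = t1 - t0
    s = (t - t0) / h
    # Standard cubic Hermite basis functions on the unit interval.
    h00 = 2 * s**3 - 3 * s**2 + 1
    h10 = s**3 - 2 * s**2 + s
    h01 = -2 * s**3 + 3 * s**2
    h11 = s**3 - s**2
    return h00 * y0 + h10 * h * dy0 + h01 * y1 + h11 * h * dy1


if __name__ == "__main__":
    # Two stored points of the forward solution of dy/dt = -y, y(0) = 1.
    ta, tb = 0.0, 0.1
    ya, yb = math.exp(-ta), math.exp(-tb)
    mid = hermite_interpolate(ta, ya, -ya, tb, yb, -yb, 0.05)
    print(mid, math.exp(-0.05))
```

The backward solver requests the forward state at times that generally do not coincide with stored steps, which is why some interpolation scheme of this kind is needed.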
The adjoint sensitivity method is applied within different domains
of science. For example, the method is present in engineering
literature and has found its way into packages like
DifferentialEquations.jl in Julia. Other notable references are
FATODE and a discourse post from Ben.
Another domain where the adjoint sensitivity method is being used is systems biology. The AMICI toolkit has been built to solve ODE systems for large scale systems biology equation systems to be used within optimizer software. The adjoint method is implemented in AMICI using CVODES and has been benchmarked on various problems from systems biology as published (see also this arxiv pre-print).
However, so far the adjoint method is commonly built into special purpose software, as the method requires a fixed target functional. The integration of the adjoint ODE sensitivity method into a general purpose autodiff library such as Stan Math is presumably the first implementation of its kind.
The main unresolved question at this point is to settle on the set of exposed tuning parameters. Currently the proposal is to expose in an experimental version a large super-set of tuning parameters and a version which resembles the enhanced interface with tolerances of the existing ODE solvers. This allows users to quickly try out the adjoint method by merely changing a function name. The suggestion is to collect feedback from the Stan community during an experimental phase of this feature. Once feedback on the utility of the simplified interface and the appropriateness of the set defaults has been collected, a decision should be made about including this signature in a first version included in the next release. Another goal of the experimental phase is to explore the possibility of weeding out some tuning parameters if possible or adding others if needed.