Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closures for ODEs #2094

Closed
wants to merge 7 commits into from
Closed

Conversation

nhuurre
Copy link
Collaborator

@nhuurre nhuurre commented Sep 22, 2020

Summary

Adds support for closures. A closure is a callable object that also captures some autodiff variables.

The autodiff API is an extension of the existing tools for variadic arguments. The following functions now support closures

count_vars(...)
deep_copy_vars(...)
accumulate_adjoints(...)
save_varis(...)
zero_adjoints(...)
value_of(...)

These should be enough for any higher-order function that takes variadic arguments to (almost) automatically support closures. So far I've been working with the ODE solvers.

Tests

There's a couple of new tests for ode_rk45.

Side Effects

None, I think.

Release notes

Add basic API for closure objects.

Checklist

  • Math issue Implement closures #2197

  • Copyright holder: Niko Huurre

    The copyright holder is typically you or your assignee, such as a university or company. By submitting this pull request, the copyright holder is agreeing to the license the submitted work under the following licenses:
    - Code: BSD 3-clause (https://opensource.org/licenses/BSD-3-Clause)
    - Documentation: CC-BY 4.0 (https://creativecommons.org/licenses/by/4.0/)

  • the basic tests are passing

    • unit tests pass (to run, use: ./runTests.py test/unit)
    • header checks pass, (make test-headers)
    • dependencies checks pass, (make test-math-dependencies)
    • docs build, (make doxygen)
    • code passes the built in C++ standards checks (make cpplint)
  • the code is written in idiomatic C++ and changes are documented in the doxygen

  • the new changes are tested

@nhuurre nhuurre mentioned this pull request Sep 22, 2020
5 tasks
@stan-buildbot
Copy link
Contributor


Name Old Result New Result Ratio Performance change( 1 - new / old )
gp_pois_regr/gp_pois_regr.stan 4.19 4.27 0.98 -1.8% slower
low_dim_corr_gauss/low_dim_corr_gauss.stan 0.02 0.02 0.99 -0.98% slower
eight_schools/eight_schools.stan 0.09 0.09 0.98 -2.32% slower
gp_regr/gp_regr.stan 0.18 0.18 1.0 0.21% faster
irt_2pl/irt_2pl.stan 6.55 6.56 1.0 -0.24% slower
performance.compilation 90.07 87.57 1.03 2.77% faster
low_dim_gauss_mix_collapse/low_dim_gauss_mix_collapse.stan 8.27 8.38 0.99 -1.31% slower
pkpd/one_comp_mm_elim_abs.stan 29.36 33.71 0.87 -14.81% slower
sir/sir.stan 129.08 133.87 0.96 -3.71% slower
gp_regr/gen_gp_data.stan 0.05 0.05 1.02 2.01% faster
low_dim_gauss_mix/low_dim_gauss_mix.stan 3.3 3.3 1.0 0.0% slower
pkpd/sim_one_comp_mm_elim_abs.stan 0.38 0.49 0.77 -30.24% slower
arK/arK.stan 2.55 2.56 1.0 -0.34% slower
arma/arma.stan 0.73 0.73 1.0 -0.13% slower
garch/garch.stan 0.72 0.73 0.99 -0.57% slower
Mean result: 0.971901268816

Jenkins Console Log
Blue Ocean
Commit hash: c0eb7cd


Machine information ProductName: Mac OS X ProductVersion: 10.11.6 BuildVersion: 15G22010

CPU:
Intel(R) Xeon(R) CPU E5-1680 v2 @ 3.00GHz

G++:
Configured with: --prefix=/Applications/Xcode.app/Contents/Developer/usr --with-gxx-include-dir=/usr/include/c++/4.2.1
Apple LLVM version 7.0.2 (clang-700.1.81)
Target: x86_64-apple-darwin15.6.0
Thread model: posix

Clang:
Apple LLVM version 7.0.2 (clang-700.1.81)
Target: x86_64-apple-darwin15.6.0
Thread model: posix

@bbbales2
Copy link
Member

bbbales2 commented Oct 4, 2020

Thanks for doing this. Apologies for the radio silence. I'll look through this in the not-infinite future. Probably after the feature freeze? This will be really good to have.

Has there been any discussion at the language level about how to do the functors? (and are there any big design decisions away from what you've coded here and we talked about previously?)

Copy link
Member

@bbbales2 bbbales2 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Finally got around to looking at this! Apologies for the delayyy. Left a question about the todo list.

Now that you wrote it, do you think this design makes sense? I'm trying to figure out drawbacks and stuff like that.

The test at line 268 of test/unit/math/rev/functor/ode_rk45_rev_test.cpp is the thing that brings this all together: https://github.com/stan-dev/math/pull/2094/files#diff-818d2074acd865781bdb17e5b65c089d59c893e0c248853b14aab511145d778cR268

Edit: whoops, that last sentence ("The test at line...") you can ignore.

inline void check_finite(const char* function, const char* name, const T_y& y) {
if (check_finite_screen(y)) {
auto is_good = [](const auto& y) { return std::isfinite(y); };
elementwise_check(is_good, function, name, y, ", but must be finite!");
}
}

template <typename T, require_stan_closure_t<T>* = nullptr>
inline void check_finite(const char* function, const char* name, const T& y) {}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess a closure has variables, so we'd want to check that the values of the variables in the closure are finite

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I guess. I was mostly fighting the compiler here.
I'm not sure why the ODEs try and check the inputs are finite. Infinite inputs don't necessarily cause infinite outputs.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I would like to get rid of these checks too.


template <typename F, typename T>
auto from_lambda(F f, T a) {
return simple_closure<F, T>(f, a);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems to be the key to how all this works. This accepts a lambda and an argument. This binds a to the first argument of the function f, and this object is treated as a closure (which can be passed around wherever).

THe closure itself, because it can contain vars, is treated like one in all the ODE code (so there are specializations for save_varis, deep_copy_vars, etc.).

This sound right?

Presumably we'd need to expand from_lamba to take a variable list of arguments and modify all the other higher order functions to work with this?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, that's right.

Actually, from_lambda isn't necessary because my plan was to leave implementing closure structs to the stanc codegen. require_stan_closure_t<...> recognizes closure simply by the presence of ::captured_scalar_t__; they are effectively duck-typed. simple_closure is one example of how to make one.
Of course, an alternative (complementary?) path forward is expanding simple_closure so that the template type T is a parameter pack.

Other higher-order functions just need support for variadic arguments. Shortly after submitting this PR I realized you could revert all the changes to ODEs and instead have an adapter like this

struct closure_adapter_ode {
  template<typename F, typename T0, typename T1, typename... Args>
  auto operator()(const T0& t, const Eigen::Matrix<T1, Eigen::Dynamic, 1>& y,
                  std::ostream* msgs, F& f, Args... args) {
    return f(msgs, t, y, args...);
  }
}

template <typename F, typename T_y0, typename T_t0, typename T_ts,
          typename... Args, require_stan_closure_t<F>* = nullptr>
std::vector<Eigen::Matrix<stan::return_type_t<T_y0, T_t0, T_ts, Args...>,
                          Eigen::Dynamic, 1>>
ode_rk45_tol_impl(const char* function_name, const F& f,
                  const Eigen::Matrix<T_y0, Eigen::Dynamic, 1>& y0_arg, T_t0 t0,
                  const std::vector<T_ts>& ts, double relative_tolerance,
                  double absolute_tolerance,
                  long int max_num_steps,
                  std::ostream* msgs, const Args&... args) {
  closure_adapter_ode adapter;
  return ode_rk45_tol_impl(function_name, adapter, y0_arg, t0, ts,
                           relative_tolerance, absolute_tolerance,
                           max_num_steps, msgs, f, args...);
}

Same goes for reduce_sum and others.

std::vector<var> a1 = {0.75};

auto f = stan::math::from_lambda(
[&](const auto& a, const auto& t, const auto& y, const auto& b,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this have no capture? If it captures by reference it might absorb something we don't want.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's me not being very familiar with C++. I had assumed the capture default was just part of lambda syntax.

@bbbales2
Copy link
Member

Cool beans.

I guess we have a few options from here:

  1. Survey the other higher order functions and think about what it will take to work with them (reduce_sum, map_rect, integrate_1d, and the algebra solver)

  2. Do something at the compiler level and try to get an ODE thing working end-to-end

  3. Develop the C++ interface more as if we are doing this for C++ lambdas

  4. Get a more in-depth critique of the current code (maybe grab some other people too)

Which do you think is best? Or do you have some other idea? Feel free to allocate work to me.

@nhuurre
Copy link
Collaborator Author

nhuurre commented Nov 15, 2020

I can make something that works end-to-end. (The previous prototype was working end-to-end, though the stanc3 PR is way outdated by now.)

You could help figuring out why this reduce_sum test doesn't compile

struct closure_adapter {
  template<typename F, typename T_slice, typename... Args>
  auto operator()(const T_slice& subslice, std::size_t start,
                  std::size_t end, std::ostream* msgs,
                  const F& f, Args... args) {
    return f(msgs, subslice, start, end, args...);
  }
};

TEST(StanMathRev_reduce_sum, grouped_gradient_closure) {
  using stan::math::var;
  using stan::math::from_lambda;
  using stan::math::test::get_new_msg;

  double lambda_d = 10.0;
  const std::size_t groups = 10;
  const std::size_t elems_per_group = 1000;
  const std::size_t elems = groups * elems_per_group;

  std::vector<int> data(elems);
  std::vector<int> gidx(elems);

  for (std::size_t i = 0; i != elems; ++i) {
    data[i] = i;
    gidx[i] = i / elems_per_group;
  }

  std::vector<var> vlambda_v;

  for (std::size_t i = 0; i != groups; ++i)
    vlambda_v.push_back(i + 0.2);

  var lambda_v = vlambda_v[0];

  auto functor = from_lambda(
    [](auto& lambda, auto& slice, std::size_t start, std::size_t end, auto& gidx, std::ostream * msgs) {
      const std::size_t num_terms = end - start + 1;
      std::decay_t<decltype(lambda)> lambda_slice(num_terms);
      for (std::size_t i = 0; i != num_terms; ++i)
        lambda_slice[i] = lambda[gidx[start + i]];
      return stan::math::poisson_lpmf(slice, lambda_slice);
    }, vlambda_v);

  var poisson_lpdf = stan::math::reduce_sum<closure_adapter>(
      data, 5, get_new_msg(), functor, gidx);

  std::vector<var> vref_lambda_v;
  for (std::size_t i = 0; i != elems; ++i) {
    vref_lambda_v.push_back(vlambda_v[gidx[i]]);
  }
  var lambda_ref = vlambda_v[0];
  var poisson_lpdf_ref = stan::math::poisson_lpmf(data, vref_lambda_v);

  EXPECT_FLOAT_EQ(value_of(poisson_lpdf), value_of(poisson_lpdf_ref));

  stan::math::grad(poisson_lpdf_ref.vi_);
  const double lambda_ref_adj = lambda_ref.adj();

  stan::math::set_zero_all_adjoints();
  stan::math::grad(poisson_lpdf.vi_);
  const double lambda_adj = lambda_v.adj();

  EXPECT_FLOAT_EQ(lambda_adj, lambda_ref_adj)
      << "ref value of poisson lpdf : " << poisson_lpdf_ref.val() << std::endl
      << "ref gradient wrt to lambda: " << lambda_ref_adj << std::endl
      << "value of poisson lpdf : " << poisson_lpdf.val() << std::endl
      << "gradient wrt to lambda: " << lambda_adj << std::endl;

  var poisson_lpdf_static
      = stan::math::reduce_sum_static<closure_adapter>(
          data, 5, get_new_msg(), functor, gidx);

  stan::math::set_zero_all_adjoints();
  stan::math::grad(poisson_lpdf_static.vi_);
  const double lambda_adj_static = lambda_v.adj();
  EXPECT_FLOAT_EQ(lambda_adj_static, lambda_ref_adj);
  stan::math::recover_memory();
}

The input types are vector<int>, simple_closure<var>, vector<int> and somehow it deduces return type as double when it should be var. The return_type_t<Args> does not seem to find ::captured_scalar_t__ and assumes the closure has no vars in it?
That doesn't make sense to me because the higher_order test in ode_rk45_rev_test.cpp does work despite also having to pull the return type out of a closure.

@bbbales2
Copy link
Member

@nhuurre sounds good, I'll look at the reduce_sum thing tomorrow.

I 'spose it's time to do the variadic thing for integrate_1d and the algebra solver as well.

@bbbales2
Copy link
Member

The return_type_t does not seem to find ::captured_scalar_t__ and assumes the closure has no vars in it?

Looking at the ODE example how does return_type_t know about ::captured_scalar_t__? Is this coded into scalar_type_t somewhere?

@nhuurre
Copy link
Collaborator Author

nhuurre commented Nov 16, 2020

Yes, scalar_type is supposed to be defined is_stan_closure.hpp, although now I wonder if the ode is using fn_return_type.

@bbbales2
Copy link
Member

You okay with me pushing code into this branch? I can do it separate with pull reqs if you want.

I made some changes to reduce_sum. It wasn't too bad to code this up, though the address sanitizer is telling me there is a memory leak somewhere still.

Before I went further I wanted to stop and ask about the msgs argument though. It is quite annoying having it in all the different places. If we move it we have to bump the math versions, but what do you think about always making it the first argument?

@nhuurre
Copy link
Collaborator Author

nhuurre commented Nov 16, 2020

Pushing here is fine.
msgs is indeed annoying. I think a reasonable rule is that closure objects take msgs as the first argument and everything else falls back to backwards-compatible calling convention (ie. msgs last or right before variadic arguments).

@bbbales2
Copy link
Member

back to backwards-compatible calling convention (ie. msgs last or right before variadic arguments)

Alright I added a reduce_sum_closure_adapter and things seem to pass the tests now (though the address sanitizer says I'm leaking memory somewhere and I have no reason to doubt it lol).

@nhuurre
Copy link
Collaborator Author

nhuurre commented Feb 22, 2021

I was too lazy to resolve the merge conflicts so I opened a new PR with less code: #2384

@nhuurre nhuurre closed this Feb 22, 2021
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants