Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable JAX backend #812

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open

Enable JAX backend #812

wants to merge 5 commits into from

Conversation

timmens
Copy link
Collaborator

@timmens timmens commented Jan 21, 2025

Caution

This PR is only open to showcase the required changes to enable JAX. When #804 is merged, we will start implementing the changes --that might be here or in a fresh PR.

Enable JAX backend

In this PR, we ensure that gettsim becomes fully operational with the jax backend and is tested accordingly such that future changes in the codebase that are not JAX-compatible trigger test failures.

Closes #515

Issues

To make GETTSIM fully JAX-operable, the tax and transfers function, returned by dags, needs to JAX-jittable. This requires all functions defined in GETTSIM and called internally to be JAX-jittable. In particular, they cannot use functions from numpy or array methods that only work on numpy arrays.

ToDo's

  • Make sure that _vectorize_func() is actually used and at correct position in code
  • Ensure that make_vectorizable can handle lambda functions

Testing

  • JAX end-to-end tests (i.e., jax.jit the compute taxes and transfers functions)
  • JAX unit tests (i.e., jax.jit individual functions and call them on JAX arrays)
  • Run (some of them) JAX tests on GHA

Numpy replacement

  • Replace numpy.function(...) by numpy_or_jax.function(...)
    • asarray
    • allclose
    • searchsorted
    • array
    • zeros
    • inf
    • full_like
    • min
    • max
  • Write JAX compatible join
    Currently, join_numpy is implemented. Either (1) implement join_jax and call if required, or (2) rewrite join_numpy using Array API standard.

Datetime replacement/refactoring

  • Make sure that numpy.datetime64 is used before running dags
    (We want to allow the user to pass in data frames with DateTime-objects. However, these should be transformed internally to day, month, and year ints.)

Piecewise Functions

  • Rewrite piecewise functions code
    Either create (1) one version of the code for numpy and one of JAX, or (2) try to solve this problem using the Array API standard.

Copy link

codecov bot commented Jan 21, 2025

Codecov Report

Attention: Patch coverage is 66.66667% with 1 line in your changes missing coverage. Please review.

Project coverage is 87.08%. Comparing base (33ef921) to head (6d19b9d).

Files with missing lines Patch % Lines
src/_gettsim/policy_environment_postprocessor.py 66.66% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #812      +/-   ##
==========================================
- Coverage   87.82%   87.08%   -0.75%     
==========================================
  Files          56       56              
  Lines        3976     3971       -5     
==========================================
- Hits         3492     3458      -34     
- Misses        484      513      +29     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

ENH: Enable vectorizaton / Jax
1 participant