Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Caution
This PR is only open to showcase the required changes to enable JAX. When #804 is merged, we will start implementing the changes --that might be here or in a fresh PR.
Enable JAX backend
In this PR, we ensure that
gettsim
becomes fully operational with thejax
backend and is tested accordingly such that future changes in the codebase that are not JAX-compatible trigger test failures.Closes #515
Issues
To make GETTSIM fully JAX-operable, the tax and transfers function, returned by
dags
, needs to JAX-jittable. This requires all functions defined in GETTSIM and called internally to be JAX-jittable. In particular, they cannot use functions fromnumpy
or array methods that only work onnumpy
arrays.ToDo's
_vectorize_func()
is actually used and at correct position in codemake_vectorizable
can handle lambda functionsTesting
jax.jit
the compute taxes and transfers functions)jax.jit
individual functions and call them on JAX arrays)Numpy replacement
numpy.function(...)
bynumpy_or_jax.function(...)
asarray
allclose
searchsorted
array
zeros
inf
full_like
min
max
join
Currently,
join_numpy
is implemented. Either (1) implementjoin_jax
and call if required, or (2) rewritejoin_numpy
using Array API standard.Datetime replacement/refactoring
numpy.datetime64
is used before runningdags
(We want to allow the user to pass in data frames with DateTime-objects. However, these should be transformed internally to day, month, and year ints.)
Piecewise Functions
Either create (1) one version of the code for
numpy
and one of JAX, or (2) try to solve this problem using the Array API standard.