Compile to JAX to enable GPU/TPU acceleration and vmap. #209

Open
yebai opened this issue Aug 27, 2024 · 4 comments

Comments

@yebai
Member

yebai commented Aug 27, 2024

No description provided.

@sunxd3
Member

sunxd3 commented Aug 28, 2024

I really want to make this work. I'll spend some time on it and try to produce a prototype soon.

@yebai
Member Author

yebai commented Aug 28, 2024

If we utilise numpyro distributions, this looks quite doable: https://num.pyro.ai/en/stable/distributions.html

@sunxd3
Member

sunxd3 commented Aug 28, 2024

TensorFlow Probability and plain jax.random are also good options.

@yebai
Member Author

yebai commented Aug 28, 2024

jax.random only provides samplers for common distributions.

DeepMind's Distrax reimplements a subset of TFP natively in JAX.
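To make the distinction concrete, here is a minimal sketch of what plain JAX offers out of the box. `jax.random` draws samples. Log densities for some common families live separately in `jax.scipy.stats` as plain functions, with no distribution objects, constraints, or bijectors. That is roughly the gap a library such as NumPyro, TFP, or Distrax fills. The parameter values below are arbitrary.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import gamma, norm

key = jax.random.PRNGKey(0)
k1, k2 = jax.random.split(key)

# jax.random: samplers only
x = 1.0 + 2.0 * jax.random.normal(k1, (1000,))  # draws from Normal(loc=1, scale=2)
g = jax.random.gamma(k2, 3.0, (1000,))          # draws from Gamma(shape=3, scale=1)

# jax.scipy.stats: standalone log-density functions for the same families
lp_x = norm.logpdf(x, loc=1.0, scale=2.0)
lp_g = gamma.logpdf(g, 3.0)

print(x.mean(), g.mean(), lp_x.sum(), lp_g.sum())
```

A model compiler could target these functions directly. However, it would then have to hand-write support for every family that `jax.scipy.stats` lacks, plus any parameter transforms.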

Labels: None yet
Projects: None yet
Development: No branches or pull requests
2 participants