-
Notifications
You must be signed in to change notification settings - Fork 374
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat(train) Add support for torch.compile (EXPERIMENTAL) #2931
Conversation
@ori-kron-wis Can you add tests for all pytorch models (not pyro and not jax) for compile. Can you check speed improvements on your end? You should execute it with: |
trainingplans with future imports and test_compute_kl revert
Needs tests like: |
I added torch compile tests for most models (of course not working with the github action due to that error) - on new servers, it worked fine and was faster, although the compile part will add some overhead. Currently pyro test not working on a multi GPU machine. Need to see why (only test_pyro_bayesian_regression). once we remove it everything works (it should be passed here) |
for more information, see https://pre-commit.ci
for more information, see https://pre-commit.ci
This branch, not included its tests, is merged to main together with MPS fix in : #3100 |
No description provided.