Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(train) Add support for torch.compile (EXPERIMENTAL) #2931

Closed
wants to merge 58 commits into from

Conversation

canergen
Copy link
Member

@canergen canergen commented Aug 7, 2024

No description provided.

@canergen canergen added the cuda tests Run test suite on CUDA label Aug 7, 2024
@canergen
Copy link
Member Author

canergen commented Aug 7, 2024

@ori-kron-wis Can you add tests for all pytorch models (not pyro and not jax) for compile. Can you check speed improvements on your end? You should execute it with: model.train(accelerator='cuda', plan_kwargs={'n_epochs_kl_warmup': 100, 'compile': True}, datasplitter_kwargs={'drop_last': True})

trainingplans with future imports and test_compute_kl revert
@ori-kron-wis ori-kron-wis self-assigned this Aug 26, 2024
@ori-kron-wis ori-kron-wis self-requested a review August 26, 2024 14:49
@canergen
Copy link
Member Author

Needs tests like: model2.train(accelerator='cuda', batch_size=5000, max_epochs=100, train_size=0.9, plan_kwargs={'n_epochs_kl_warmup': 100, 'compile': True}, datasplitter_kwargs={'drop_last': True}) and then get_elbo, get_reconstruction_loss, get_latent.

@ori-kron-wis
Copy link
Collaborator

ori-kron-wis commented Sep 17, 2024

I added torch compile tests for most models (of course not working with the github action due to that error) - on new servers, it worked fine and was faster, although the compile part will add some overhead.

Currently pyro test not working on a multi GPU machine. Need to see why (only test_pyro_bayesian_regression). once we remove it everything works (it should be passed here)

@ori-kron-wis
Copy link
Collaborator

This branch, not included its tests, is merged to main together with MPS fix in : #3100

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
cuda optional tests cuda tests Run test suite on CUDA
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants