AMD (MI250X) support #1775

Merged: 7 commits merged into Lightning-AI:main from the mi250x-support branch on Oct 10, 2024

Conversation

@TensorTemplar (Contributor) commented Oct 5, 2024

This PR adds AMD support via the following changes:

  • split the NVLink check logic to branch on the device name reported by torch: if the name contains "amd", the topology is queried with rocm-smi --showtopotype instead of nvidia-smi (see the sketch after this list).
  • rewrote build_rope_cache to use a vectorized approach instead of boolean-mask (nonzero) indexing. I did not benchmark it against main, but it should be quicker as a bonus. The previous code did not work on my hardware, probably due to device detection issues: build_rope_cache received device=None, attempted the nonzero indexing on the meta device, and failed because no meta kernel is registered for that op 🤷‍♂️
  • added debug prints in build_rope_cache, since there is no logger (I will remove them once reviewed; if you have preferences on how to log, let me know).
  • reformatted the edited files via isort and black - let me know if you prefer the unlinted versions. Alternatively, we could add an optional pre-commit hook in a separate PR.
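For illustration, a rough sketch of the device-name branch described in the first bullet; the function names, the exact nvidia-smi invocation, and the output parsing are hypothetical stand-ins rather than the PR's actual code:

```python
import subprocess

import torch


def topology_command(device_name: str) -> list[str]:
    # Branch on the device name reported by torch: AMD parts are queried via
    # rocm-smi, everything else keeps using nvidia-smi.
    # ("nvidia-smi topo -m" is a placeholder for the project's real NVLink query.)
    if "amd" in device_name.lower():
        return ["rocm-smi", "--showtopotype"]
    return ["nvidia-smi", "topo", "-m"]


def gpus_fully_connected(device_index: int = 0) -> bool:
    name = torch.cuda.get_device_name(device_index)
    proc = subprocess.run(topology_command(name), capture_output=True, text=True)
    if proc.returncode != 0:
        return False
    # Treat XGMI (AMD) or NVLink (NVIDIA) appearing in the topology output as
    # "fast interconnect available"; real parsing would be more careful.
    out = proc.stdout.lower()
    return "xgmi" in out or "nvlink" in out
```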

Test results on my machine (Linux with a 4090):

1337 passed, 40 skipped, 70 xfailed, 2 xpassed, 381 warnings in 430.23s (0:07:10)

Testing finetune_lora with Llama 3.1 8B on an 8xMI250X node; it seems to run so far with ~65-75% utilization. I am frankly quite shocked this runs quicker than our accelerate FSDP out of the box :)
[screenshot attached: Screenshot from 2024-10-05 18-08-42]

@TensorTemplar changed the title from "AMD (Mi250X) support" to "AMD (MI250X) support" on Oct 5, 2024
@TensorTemplar marked this pull request as draft on October 5, 2024 15:27
@TensorTemplar force-pushed the mi250x-support branch 2 times, most recently from d8e45cf to 03b33db, on October 5, 2024 16:11
@TensorTemplar marked this pull request as ready for review on October 5, 2024 17:06
@rasbt (Collaborator) commented Oct 7, 2024

Thanks for the PR. Just reading through your comments, the premise is actually great. Before merging, we need to check (with some concrete numbers) that this doesn't impact CUDA performance, but it sounds like it may actually even be a positive impact here 😊.

> reformatted the edited files via isort and black - let me know if you prefer the unlinted versions. Alternatively we could add the optional pre-commit in a separate PR

Yes, could you please revert those style edits back? It makes reviewing code a bit harder because there are so many additional changes. I also had many discussions with colleagues on that internally, and the general preference seems to be not to use a linter like black.

@TensorTemplar (Contributor, Author)

> Thanks for the PR. Just reading through your comments, the premise is actually great. Before merging, we need to check (with some concrete numbers) that this doesn't impact CUDA performance, but it sounds like it may actually even be a positive impact here 😊.

> reformatted the edited files via isort and black - let me know if you prefer the unlinted versions. Alternatively we could add the optional pre-commit in a separate PR

> Yes, could you please revert those style edits back? It makes reviewing code a bit harder because there are so many additional changes. I also had many discussions with colleagues on that internally, and the general preference seems to be not to use a linter like black.

Sounds good, I will reproduce some of the config hub fine-tunes on CUDA and A6000s, but I don't have any hardware running that would be a good test for multi-GPU and multi-node FSDP at the moment.

@rasbt (Collaborator) commented Oct 8, 2024

Thanks! And no worries, I can help with multi-GPU comparisons on CUDA

@TensorTemplar (Contributor, Author) commented Oct 9, 2024

Rebased and reverted the import style back to single-line, as discussed.

Please have a look at the perf benchmarks below:

NVIDIA GeForce RTX 4090
Driver Version: 560.35.03
CUDA Version: 12.6

with config_hub/finetune/llama-3.2-3B/lora.yaml
This branch:
Training time: 262.61s
Memory used: 9.67 GB
Validating ...
Final evaluation | val loss: 0.994 | val ppl: 2.701

0.5.0
Training time: 264.34s
Memory used: 9.67 GB
Validating ...
Final evaluation | val loss: 0.993 | val ppl: 2.700

with config_hub/finetune/llama-3.1-8b/lora.yaml
This branch:
Training time: 349.64s
Memory used: 19.73 GB
Validating ...
Final evaluation | val loss: 0.879 | val ppl: 2.408

0.5.0
Training time: 377.03s
Memory used: 19.73 GB
Validating ...
Final evaluation | val loss: 0.879 | val ppl: 2.409
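(Side note: the reported val ppl is just exp(val loss), so the loss/ppl pairs above are internally consistent; a quick check:)

```python
import math

# Perplexity is the exponential of the cross-entropy loss.
print(math.exp(0.879))  # ≈ 2.408, matching the 8B runs above
print(math.exp(0.994))  # ≈ 2.702, matching the 3B run on this branch
```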

New SOTA clearly established 😆

@rasbt (Collaborator) commented Oct 9, 2024

Nice, thanks for the numbers! I will also do a run on multi-GPU to confirm, but this looks awesome!

> frankly quite shocked this runs quicker than our accelerate FSDP

May I ask what the accelerate FSDP method is? Is this vanilla FSDP from PyTorch or some other method? Just curious if there's maybe some trick that we can add here to make it even faster.

@TensorTemplar (Contributor, Author) commented Oct 9, 2024

> Nice, thanks for the numbers! I will also do a run on multi-GPU to confirm, but this looks awesome!

> frankly quite shocked this runs quicker than our accelerate FSDP

> May I ask what the accelerate FSDP method is? Is this vanilla FSDP from PyTorch or some other method? Just curious if there's maybe some trick that we can add here to make it even faster.

Is FSDP ever vanilla? 😅 I cannot share the accelerate configs nor more details on that, unfortunately, but I will share some stats from Llama 70B and 405B fine-tunes if I get them to run with litgpt on larger datasets.
I am not a big fan of the HF/Accelerate/TRL UX and code quality, so I will be more than happy to port everything to fabric now ^~

…clarification comment in test

Revert failover to cpu in build_rope_cache when device is None

Add test fixture for amd multigpu xgmi and nvidia dualgpu nvlink. Update tests to use fixtures

Use fixtures for device properties. Follow existing style in fixture order. Mock subprocess.run for new tests

Use real device names in mocks

Remove redundant mocks

Remove warning print, revert import sorting to previous style

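For illustration, a minimal sketch of the kind of fixture the test-related commits above describe, with subprocess.run mocked so the AMD XGMI / NVIDIA NVLink topology checks can run without hardware (the fixture name and output format are hypothetical, not the PR's exact test code):

```python
from unittest import mock

import pytest


@pytest.fixture
def mock_amd_xgmi_topology(monkeypatch):
    # Pretend `rocm-smi --showtopotype` reports XGMI links between two GPUs.
    fake_result = mock.Mock(
        returncode=0,
        stdout="GPU0 GPU1\nGPU0 0 XGMI\nGPU1 XGMI 0\n",
        stderr="",
    )
    mocked_run = mock.Mock(return_value=fake_result)
    monkeypatch.setattr("subprocess.run", mocked_run)
    return mocked_run
```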
@rasbt (Collaborator) commented Oct 9, 2024

@TensorTemplar I reverted all the style changes in the unrelated code and just saw that your push overrode these again 😅. Could you please roll those back to make the reviewing easier and focus the PR just on the RoPE improvements and NVLink/AMD tests?

@TensorTemplar (Contributor, Author) commented Oct 9, 2024

> @TensorTemplar I reverted all the style changes in the unrelated code and just saw that your push overrode these again 😅. Could you please roll those back to make the reviewing easier and focus the PR just on the RoPE improvements and NVLink/AMD tests?

Sorry, I thought I would squash some of the commits and merge the new tests, and I overrode your changes without checking that you had committed into my branch. The style changes are now reverted.

@rasbt (Collaborator) commented Oct 9, 2024

> Sorry, I thought I would squash some of the commits and merge the new tests, and I overrode your changes without checking that you had committed into my branch. The style changes are now reverted.

No worries, and thanks for updating that!

Actually, I was trying to run the code before and after your PR on an 8xA100 machine and noticed issues with the code before your PR:

adjusted_theta[mask_low_freq] = theta[mask_low_freq] / factor
  File "/home/sebastian/miniforge3/envs/litgp2/lib/python3.10/site-packages/torch/_meta_registrations.py", line 2883, in meta_index_Tensor
    nonzero = index.nonzero()
  File "/home/sebastian/miniforge3/envs/litgp2/lib/python3.10/site-packages/torch/utils/_device.py", line 77, in __torch_function__
    return func(*args, **kwargs)
NotImplementedError: aten::nonzero: attempted to run this operator with Meta tensors, but there was no abstract impl or Meta kernel registered. 

I must admit that I didn't run the updated RoPE code in multi-GPU settings, only single GPU training. So yes, we should definitely merge your PR 😊
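(For illustration: this traceback is exactly what the vectorized rewrite sidesteps. Boolean-mask assignment lowers to aten::nonzero, which has no meta kernel, whereas an elementwise torch.where also runs on meta tensors. A minimal sketch, simplified relative to the actual build_rope_cache logic:)

```python
import torch


def adjust_low_freq(theta: torch.Tensor, mask_low_freq: torch.Tensor, factor: float) -> torch.Tensor:
    # Pre-PR style: advanced indexing with a boolean mask needs aten::nonzero
    # and therefore fails on the meta device:
    #   adjusted_theta[mask_low_freq] = theta[mask_low_freq] / factor
    # Vectorized style: purely elementwise, so it also works on meta tensors.
    return torch.where(mask_low_freq, theta / factor, theta)


# Runs even when tensors live on the meta device:
theta = torch.arange(1.0, 9.0, device="meta")
mask = theta > 4.0
adjusted = adjust_low_freq(theta, mask, factor=8.0)
```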

@TensorTemplar (Contributor, Author) commented Oct 9, 2024

> Sorry, I thought I would squash some of the commits and merge the new tests, and I overrode your changes without checking that you had committed into my branch. The style changes are now reverted.

> No worries, and thanks for updating that!

> Actually, I was trying to run the code before and after your PR on an 8xA100 machine and noticed issues with the code before your PR:

> adjusted_theta[mask_low_freq] = theta[mask_low_freq] / factor
>   File "/home/sebastian/miniforge3/envs/litgp2/lib/python3.10/site-packages/torch/_meta_registrations.py", line 2883, in meta_index_Tensor
>     nonzero = index.nonzero()
>   File "/home/sebastian/miniforge3/envs/litgp2/lib/python3.10/site-packages/torch/utils/_device.py", line 77, in __torch_function__
>     return func(*args, **kwargs)
> NotImplementedError: aten::nonzero: attempted to run this operator with Meta tensors, but there was no abstract impl or Meta kernel registered.

> I must admit that I didn't run the updated RoPE code in multi-GPU settings, only single GPU training. So yes, we should definitely merge your PR 😊

Oh, I speculated this was an issue with missing custom CUDA kernels on AMD, since I didn't see it on my local CUDA test machines.

@rasbt (Collaborator) left a review comment


Looks all good to me now! Thanks so much again for the PR!

@rasbt merged commit 46c4337 into Lightning-AI:main on Oct 10, 2024
8 of 9 checks passed