[Core][1/N] Support PP PyNCCL Groups #4988
Conversation
Signed-off-by: Muralidhar Andoorveedu <[email protected]>
@youkaichao @simon-mo Can you take a quick look at this PR?
tests/distributed/test_pynccl.py (outdated)

@@ -151,6 +152,68 @@ def test_pynccl_with_cudagraph():
    distributed_run(worker_fn_with_cudagraph, 2)


@worker_fn_wrapper
def pp_worker_fn():
Name it as a send_recv worker? It is not testing PP, though.
Ok, will change. This follows the convention from here:
You can also change that into an allreduce worker.
Done
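For orientation, the worker under discussion exercises a simple ring pattern: each rank sends a tensor to the next rank and receives one from the previous rank. The sketch below illustrates that pattern with plain torch.distributed rather than the PR's PyNcclCommunicator, so the function name and details are illustrative only, not the PR's code.

import torch
import torch.distributed as dist

def ring_send_recv_worker():
    # Assumes the default process group is already initialized with an
    # NCCL backend and that each rank owns one GPU.
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    device = torch.device(f"cuda:{rank}")

    # Each rank sends a tensor stamped with its own rank to the next rank
    # and receives one from the previous rank (wrapping around the ring).
    send_tensor = torch.full((16,), float(rank), device=device)
    recv_tensor = torch.empty(16, device=device)
    dst = (rank + 1) % world_size
    src = (rank - 1) % world_size

    # Alternate the send/recv order between even and odd ranks to avoid
    # a point-to-point deadlock on the ring.
    if rank % 2 == 0:
        dist.send(send_tensor, dst)
        dist.recv(recv_tensor, src)
    else:
        dist.recv(recv_tensor, src)
        dist.send(send_tensor, dst)

    # The received tensor should carry the previous rank's value.
    assert torch.all(recv_tensor == float(src))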
# Hardcoded to send to the next rank for now for simplicity.
dst = (self.rank + 1) % self.world_size
why hardcode here?
It's hardcoded because we only need to send to the next and previous rank for now, so it's simple. We can change this if necessary, though.
Please don't hardcode here. Just expose the general form of send/recv. You can add a default arg.
Changed it to expose the general form with a default arg of None, which makes it the next and previous rank respectively.
Let me know if that's sufficient
# Hardcoded to receive from the previous rank for now for simplicity.
src = (self.rank - 1) % self.world_size
same as above
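For reference, a minimal self-contained sketch of the default-argument behaviour described above (None resolves to the next or previous rank); PeerDefaults is an illustrative stand-in, not the PR's PyNcclCommunicator.

from typing import Optional

class PeerDefaults:
    """Illustrative only: resolve optional peer ranks the way the defaults
    above are described (None -> next/previous rank in the group)."""

    def __init__(self, rank: int, world_size: int):
        self.rank = rank
        self.world_size = world_size

    def resolve_dst(self, dst: Optional[int] = None) -> int:
        # Default destination is the next rank, wrapping around the group.
        return (self.rank + 1) % self.world_size if dst is None else dst

    def resolve_src(self, src: Optional[int] = None) -> int:
        # Default source is the previous rank, wrapping around the group.
        return (self.rank - 1) % self.world_size if src is None else src

peers = PeerDefaults(rank=2, world_size=4)
assert peers.resolve_dst() == 3       # next rank by default
assert peers.resolve_src() == 1       # previous rank by default
assert peers.resolve_dst(dst=0) == 0  # an explicit peer overrides the default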
vllm/distributed/parallel_state.py (outdated)
_PP_PYNCCL_COMMUNICATOR = PyNcclCommunicator(
    group=_PP_CPU_GROUP,
    device=_LOCAL_RANK,
)
Creating communicators will add memory cost. Please add a check here: if the PP size == 1, skip the construction of PyNcclCommunicator.
Changed both PP and TP.
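A sketch of the guard being discussed, in the context of the diff above; it assumes the surrounding module already defines PyNcclCommunicator, _PP_CPU_GROUP, and _LOCAL_RANK, and that a helper exposing the PP group size exists (the helper name below is an assumption, not necessarily the exact vLLM symbol).

# Only pay the communicator's memory cost when pipeline parallelism is
# actually in use; otherwise leave the handle as None.
_PP_PYNCCL_COMMUNICATOR = None
if get_pipeline_model_parallel_world_size() > 1:  # assumed helper name
    _PP_PYNCCL_COMMUNICATOR = PyNcclCommunicator(
        group=_PP_CPU_GROUP,
        device=_LOCAL_RANK,
    )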
@youkaichao Do you know how much extra memory the communicators would add? Is it a fixed cost, or does it scale with group size?
We are running into some OOM with CUDA graph enabled that shouldn't be happening. Any pointers would be greatly appreciated!
Signed-off-by: Muralidhar Andoorveedu <[email protected]>
looks good to me overall. thanks! let's see if it passes the tests.
Hey @youkaichao, looks like it passed, can this be merged?
Signed-off-by: Muralidhar Andoorveedu <[email protected]>
This PR adds support for multiple PP PyNCCL groups. It will help #4412 support CUDAGraph.