From 194e7c5c3b3de0cff7f798d6bb8f517f9f1b9546 Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Thu, 12 Dec 2024 01:04:19 -0800
Subject: [PATCH] [core][distributed] initialization from StatelessProcessGroup (#10986)

Signed-off-by: youkaichao
Signed-off-by: Akshat Tripathi
---
 .buildkite/test-pipeline.yaml             |  6 +-
 tests/distributed/test_same_node.py       | 29 ++++++-
 tests/distributed/test_shm_broadcast.py   | 84 ++++++++++++-------
 .../device_communicators/shm_broadcast.py | 39 ++++++---
 vllm/distributed/parallel_state.py        | 64 +++++++++-----
 5 files changed, 153 insertions(+), 69 deletions(-)

diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index aca505178df06..6a6ee3cf713ae 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -432,11 +432,11 @@ steps:
   - tests/distributed/
   commands:
   - # the following commands are for the first node, with ip 192.168.10.10 (ray environment already set up)
-  - VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep -q 'Same node test passed'
+  - VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed'
   - VLLM_MULTI_NODE=1 pytest -v -s distributed/test_multi_node_assignment.py
   - VLLM_MULTI_NODE=1 pytest -v -s distributed/test_pipeline_parallel.py
   - # the following commands are for the second node, with ip 192.168.10.11 (ray environment already set up)
-  - VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep -q 'Same node test passed'
+  - VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed'
 
 - label: Distributed Tests (2 GPUs) # 40min
   #mirror_hardwares: [amd]
@@ -455,7 +455,7 @@ steps:
   commands:
   - pytest -v -s ./compile/test_basic_correctness.py
   - pytest -v -s ./compile/test_wrapper.py
-  - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep -q 'Same node test passed'
+  - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep 'Same node test passed'
   - TARGET_TEST_SUITE=L4 pytest basic_correctness/ -v -s -m 'distributed(num_gpus=2)'
   # Avoid importing model tests that cause CUDA reinitialization error
   - pytest models/encoder_decoder/language/test_bart.py -v -s -m 'distributed(num_gpus=2)'
diff --git a/tests/distributed/test_same_node.py b/tests/distributed/test_same_node.py
index defc4e23c8ce2..62311a626bc47 100644
--- a/tests/distributed/test_same_node.py
+++ b/tests/distributed/test_same_node.py
@@ -3,11 +3,32 @@
 import torch.distributed as dist
 
 from vllm.distributed.parallel_state import in_the_same_node_as
+from vllm.distributed.utils import StatelessProcessGroup
+from vllm.utils import get_ip, get_open_port
 
 if __name__ == "__main__":
     dist.init_process_group(backend="gloo")
-    test_result = all(in_the_same_node_as(dist.group.WORLD, source_rank=0))
 
-    expected = os.environ.get("VLLM_TEST_SAME_HOST", "1") == "1"
-    assert test_result == expected, f"Expected {expected}, got {test_result}"
-    print("Same node test passed!")
+    rank = dist.get_rank()
+    if rank == 0:
+        port = get_open_port()
+        ip = get_ip()
+        dist.broadcast_object_list([ip, port], src=0)
+    else:
+        recv = [None, None]
+        dist.broadcast_object_list(recv, src=0)
+        ip, port = recv
+
+    stateless_pg = StatelessProcessGroup.create(ip, port, rank,
+                                                dist.get_world_size())
+
+    for pg in [dist.group.WORLD, stateless_pg]:
+        test_result = all(in_the_same_node_as(pg, source_rank=0))
+
+        expected = os.environ.get("VLLM_TEST_SAME_HOST", "1") == "1"
+        assert test_result == expected, \
+            f"Expected {expected}, got {test_result}"
+        if pg == dist.group.WORLD:
+            print("Same node test passed when using torch.distributed!")
+        else:
+            print("Same node test passed when using StatelessProcessGroup!")
diff --git a/tests/distributed/test_shm_broadcast.py b/tests/distributed/test_shm_broadcast.py
index 2761b7f6c0644..723872682cf97 100644
--- a/tests/distributed/test_shm_broadcast.py
+++ b/tests/distributed/test_shm_broadcast.py
@@ -7,7 +7,8 @@
 import torch.distributed as dist
 
 from vllm.distributed.device_communicators.shm_broadcast import MessageQueue
-from vllm.utils import update_environment_variables
+from vllm.distributed.utils import StatelessProcessGroup
+from vllm.utils import get_ip, get_open_port, update_environment_variables
 
 
 def get_arrays(n: int, seed: int = 0) -> List[np.ndarray]:
@@ -54,34 +55,61 @@ def wrapped_fn(env):
 
 @worker_fn_wrapper
 def worker_fn():
-    writer_rank = 2
-    broadcaster = MessageQueue.create_from_process_group(
-        dist.group.WORLD, 40 * 1024, 2, writer_rank)
-    if dist.get_rank() == writer_rank:
-        seed = random.randint(0, 1000)
-        dist.broadcast_object_list([seed], writer_rank)
-    else:
-        recv = [None]
-        dist.broadcast_object_list(recv, writer_rank)
-        seed = recv[0]  # type: ignore
-    dist.barrier()
-    # in case we find a race condition
-    # print the seed so that we can reproduce the error
-    print(f"Rank {dist.get_rank()} got seed {seed}")
-    # test broadcasting with about 400MB of data
-    N = 10_000
-    if dist.get_rank() == writer_rank:
-        arrs = get_arrays(N, seed)
-        for x in arrs:
-            broadcaster.broadcast_object(x)
-            time.sleep(random.random() / 1000)
+
+    rank = dist.get_rank()
+    if rank == 0:
+        port = get_open_port()
+        ip = get_ip()
+        dist.broadcast_object_list([ip, port], src=0)
     else:
-        arrs = get_arrays(N, seed)
-        for x in arrs:
-            y = broadcaster.broadcast_object(None)
-            assert np.array_equal(x, y)
-            time.sleep(random.random() / 1000)
-    dist.barrier()
+        recv = [None, None]
+        dist.broadcast_object_list(recv, src=0)
+        ip, port = recv
+
+    stateless_pg = StatelessProcessGroup.create(ip, port, rank,
+                                                dist.get_world_size())
+
+    for pg in [dist.group.WORLD, stateless_pg]:
+
+        writer_rank = 2
+        broadcaster = MessageQueue.create_from_process_group(
+            pg, 40 * 1024, 2, writer_rank)
+        if rank == writer_rank:
+            seed = random.randint(0, 1000)
+            dist.broadcast_object_list([seed], writer_rank)
+        else:
+            recv = [None]
+            dist.broadcast_object_list(recv, writer_rank)
+            seed = recv[0]  # type: ignore
+
+        if pg == dist.group.WORLD:
+            dist.barrier()
+        else:
+            pg.barrier()
+
+        # in case we find a race condition
+        # print the seed so that we can reproduce the error
+        print(f"Rank {rank} got seed {seed}")
+        # test broadcasting with about 400MB of data
+        N = 10_000
+        if rank == writer_rank:
+            arrs = get_arrays(N, seed)
+            for x in arrs:
+                broadcaster.broadcast_object(x)
+                time.sleep(random.random() / 1000)
+        else:
+            arrs = get_arrays(N, seed)
+            for x in arrs:
+                y = broadcaster.broadcast_object(None)
+                assert np.array_equal(x, y)
+                time.sleep(random.random() / 1000)
+
+        if pg == dist.group.WORLD:
+            dist.barrier()
+            print("torch.distributed passed the test!")
+        else:
+            pg.barrier()
+            print("StatelessProcessGroup passed the test!")
 
 
 def test_shm_broadcast():
diff --git a/vllm/distributed/device_communicators/shm_broadcast.py b/vllm/distributed/device_communicators/shm_broadcast.py
index 9a2d8918d96e5..9f97b0f01ad8a 100644
--- a/vllm/distributed/device_communicators/shm_broadcast.py
+++ b/vllm/distributed/device_communicators/shm_broadcast.py
@@ -5,7 +5,7 @@
 from contextlib import contextmanager
 from dataclasses import dataclass, field
 from multiprocessing import shared_memory
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, Union
 from unittest.mock import patch
 
 import torch
@@ -15,6 +15,7 @@
 from zmq import SUB, SUBSCRIBE, XPUB, XPUB_VERBOSE, Context  # type: ignore
 
 import vllm.envs as envs
+from vllm.distributed.utils import StatelessProcessGroup
 from vllm.logger import init_logger
 from vllm.utils import get_ip, get_open_port, is_valid_ipv6_address
 
@@ -476,13 +477,19 @@ def broadcast_object(self, obj=None):
             return self.dequeue()
 
     @staticmethod
-    def create_from_process_group(pg: ProcessGroup,
+    def create_from_process_group(pg: Union[ProcessGroup,
+                                            StatelessProcessGroup],
                                   max_chunk_bytes,
                                   max_chunks,
                                   writer_rank=0) -> "MessageQueue":
-        group_rank = dist.get_rank(pg)
-        group_world_size = dist.get_world_size(pg)
-        global_ranks = dist.get_process_group_ranks(pg)
+        if isinstance(pg, ProcessGroup):
+            group_rank = dist.get_rank(pg)
+            group_world_size = dist.get_world_size(pg)
+            global_ranks = dist.get_process_group_ranks(pg)
+        else:
+            group_rank = pg.rank
+            group_world_size = pg.world_size
+            global_ranks = list(range(pg.world_size))
 
         from vllm.distributed.parallel_state import in_the_same_node_as
         status = in_the_same_node_as(pg, source_rank=writer_rank)
@@ -500,15 +507,21 @@ def create_from_process_group(pg: ProcessGroup,
                 max_chunks=max_chunks,
             )
             handle = buffer_io.export_handle()
-            dist.broadcast_object_list([handle],
-                                       src=global_ranks[writer_rank],
-                                       group=pg)
+            if isinstance(pg, ProcessGroup):
+                dist.broadcast_object_list([handle],
+                                           src=global_ranks[writer_rank],
+                                           group=pg)
+            else:
+                pg.broadcast_obj(handle, writer_rank)
         else:
-            recv = [None]
-            dist.broadcast_object_list(recv,
-                                       src=global_ranks[writer_rank],
-                                       group=pg)
-            handle = recv[0]  # type: ignore
+            if isinstance(pg, ProcessGroup):
+                recv = [None]
+                dist.broadcast_object_list(recv,
+                                           src=global_ranks[writer_rank],
+                                           group=pg)
+                handle = recv[0]  # type: ignore
+            else:
+                handle = pg.broadcast_obj(None, writer_rank)
         buffer_io = MessageQueue.create_from_handle(handle, group_rank)
         buffer_io.wait_until_ready()
         return buffer_io
diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py
index 34815d7f0aa78..5b9236f8c56b6 100644
--- a/vllm/distributed/parallel_state.py
+++ b/vllm/distributed/parallel_state.py
@@ -37,6 +37,7 @@
 import vllm.distributed.kv_transfer.kv_transfer_agent as kv_transfer
 import vllm.envs as envs
+from vllm.distributed.utils import StatelessProcessGroup
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.utils import direct_register_custom_op, supports_custom_op
 
@@ -1191,25 +1192,31 @@ def cleanup_dist_env_and_memory(shutdown_ray: bool = False):
         torch.cuda.empty_cache()
 
 
-def in_the_same_node_as(pg: ProcessGroup, source_rank: int = 0) -> List[bool]:
+def in_the_same_node_as(pg: Union[ProcessGroup, StatelessProcessGroup],
+                        source_rank: int = 0) -> List[bool]:
     """
     This is a collective operation that returns if each rank is in the same
     node as the source rank. It tests if processes are attached to the same
     memory system (shared access to shared memory).
""" - assert torch.distributed.get_backend( - pg) != torch.distributed.Backend.NCCL, ( - "in_the_same_node_as should be tested with a non-NCCL group.") - # local rank inside the group - rank = torch.distributed.get_rank(group=pg) - world_size = torch.distributed.get_world_size(group=pg) + if isinstance(pg, ProcessGroup): + assert torch.distributed.get_backend( + pg) != torch.distributed.Backend.NCCL, ( + "in_the_same_node_as should be tested with a non-NCCL group.") + # local rank inside the group + rank = torch.distributed.get_rank(group=pg) + world_size = torch.distributed.get_world_size(group=pg) + + # global ranks of the processes in the group + ranks = torch.distributed.get_process_group_ranks(pg) + else: + rank = pg.rank + world_size = pg.world_size + ranks = list(range(world_size)) # local tensor in each process to store the result is_in_the_same_node = torch.tensor([0] * world_size, dtype=torch.int32) - # global ranks of the processes in the group - ranks = torch.distributed.get_process_group_ranks(pg) - magic_message = b"magic_message" shm = None @@ -1219,17 +1226,21 @@ def in_the_same_node_as(pg: ProcessGroup, source_rank: int = 0) -> List[bool]: # create a shared memory segment shm = shared_memory.SharedMemory(create=True, size=128) shm.buf[:len(magic_message)] = magic_message - torch.distributed.broadcast_object_list([shm.name], - src=ranks[source_rank], - group=pg) + if isinstance(pg, ProcessGroup): + torch.distributed.broadcast_object_list( + [shm.name], src=ranks[source_rank], group=pg) + else: + pg.broadcast_obj(shm.name, src=source_rank) is_in_the_same_node[rank] = 1 else: # try to open the shared memory segment - recv = [None] - torch.distributed.broadcast_object_list(recv, - src=ranks[source_rank], - group=pg) - name = recv[0] + if isinstance(pg, ProcessGroup): + recv = [None] + torch.distributed.broadcast_object_list( + recv, src=ranks[source_rank], group=pg) + name = recv[0] + else: + name = pg.broadcast_obj(None, src=source_rank) # fix to https://stackoverflow.com/q/62748654/9191338 # Python incorrectly tracks shared memory even if it is not # created by the process. The following patch is a workaround. @@ -1244,12 +1255,23 @@ def in_the_same_node_as(pg: ProcessGroup, source_rank: int = 0) -> List[bool]: if shm: shm.close() - torch.distributed.barrier(group=pg) + if isinstance(pg, ProcessGroup): + torch.distributed.barrier(group=pg) + else: + pg.barrier() # clean up the shared memory segment with contextlib.suppress(OSError): if rank == source_rank and shm: shm.unlink() - torch.distributed.all_reduce(is_in_the_same_node, group=pg) - return [x == 1 for x in is_in_the_same_node.tolist()] + if isinstance(pg, ProcessGroup): + torch.distributed.all_reduce(is_in_the_same_node, group=pg) + aggregated_data = is_in_the_same_node + else: + aggregated_data = torch.zeros_like(is_in_the_same_node) + for i in range(world_size): + rank_data = pg.broadcast_obj(is_in_the_same_node, src=i) + aggregated_data += rank_data + + return [x == 1 for x in aggregated_data.tolist()]