[core][distributed] initialization from StatelessProcessGroup (#10986)
Signed-off-by: youkaichao <[email protected]>
Signed-off-by: Akshat Tripathi <[email protected]>
youkaichao authored and Akshat-Tripathi committed Dec 12, 2024
1 parent 9fabf20 commit 194e7c5
Showing 5 changed files with 153 additions and 69 deletions.
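
In short, this commit lets in_the_same_node_as and MessageQueue.create_from_process_group accept a StatelessProcessGroup in addition to a regular torch.distributed.ProcessGroup, so neither call requires torch.distributed.init_process_group any more. A minimal sketch of the new path, assuming each of the WORLD_SIZE processes is launched with RANK, WORLD_SIZE, MASTER_IP and MASTER_PORT set (these variable names are illustrative, not part of the vLLM API):

import os

from vllm.distributed.device_communicators.shm_broadcast import MessageQueue
from vllm.distributed.parallel_state import in_the_same_node_as
from vllm.distributed.utils import StatelessProcessGroup

rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])

# No torch.distributed.init_process_group() is needed.
pg = StatelessProcessGroup.create(os.environ["MASTER_IP"],
                                  int(os.environ["MASTER_PORT"]), rank,
                                  world_size)

# Collective check: which ranks share a node with rank 0?
same_node = in_the_same_node_as(pg, source_rank=0)

# Shared-memory broadcast queue with rank 0 as the writer.
queue = MessageQueue.create_from_process_group(pg, 1 << 20, 2, writer_rank=0)
if rank == 0:
    queue.broadcast_object({"hello": "stateless world"})
else:
    print(queue.broadcast_object(None))
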
6 changes: 3 additions & 3 deletions .buildkite/test-pipeline.yaml
@@ -432,11 +432,11 @@ steps:
- tests/distributed/
commands:
- # the following commands are for the first node, with ip 192.168.10.10 (ray environment already set up)
- VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep -q 'Same node test passed'
- VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed'
- VLLM_MULTI_NODE=1 pytest -v -s distributed/test_multi_node_assignment.py
- VLLM_MULTI_NODE=1 pytest -v -s distributed/test_pipeline_parallel.py
- # the following commands are for the second node, with ip 192.168.10.11 (ray environment already set up)
- VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep -q 'Same node test passed'
- VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed'

- label: Distributed Tests (2 GPUs) # 40min
#mirror_hardwares: [amd]
@@ -455,7 +455,7 @@
commands:
- pytest -v -s ./compile/test_basic_correctness.py
- pytest -v -s ./compile/test_wrapper.py
- VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep -q 'Same node test passed'
- VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep 'Same node test passed'
- TARGET_TEST_SUITE=L4 pytest basic_correctness/ -v -s -m 'distributed(num_gpus=2)'
# Avoid importing model tests that cause CUDA reinitialization error
- pytest models/encoder_decoder/language/test_bart.py -v -s -m 'distributed(num_gpus=2)'
29 changes: 25 additions & 4 deletions tests/distributed/test_same_node.py
@@ -3,11 +3,32 @@
import torch.distributed as dist

from vllm.distributed.parallel_state import in_the_same_node_as
from vllm.distributed.utils import StatelessProcessGroup
from vllm.utils import get_ip, get_open_port

if __name__ == "__main__":
dist.init_process_group(backend="gloo")
test_result = all(in_the_same_node_as(dist.group.WORLD, source_rank=0))

expected = os.environ.get("VLLM_TEST_SAME_HOST", "1") == "1"
assert test_result == expected, f"Expected {expected}, got {test_result}"
print("Same node test passed!")
rank = dist.get_rank()
if rank == 0:
port = get_open_port()
ip = get_ip()
dist.broadcast_object_list([ip, port], src=0)
else:
recv = [None, None]
dist.broadcast_object_list(recv, src=0)
ip, port = recv

stateless_pg = StatelessProcessGroup.create(ip, port, rank,
dist.get_world_size())

for pg in [dist.group.WORLD, stateless_pg]:
test_result = all(in_the_same_node_as(pg, source_rank=0))

expected = os.environ.get("VLLM_TEST_SAME_HOST", "1") == "1"
assert test_result == expected, \
f"Expected {expected}, got {test_result}"
if pg == dist.group.WORLD:
print("Same node test passed! when using torch distributed!")
else:
print("Same node test passed! when using StatelessProcessGroup!")
84 changes: 56 additions & 28 deletions tests/distributed/test_shm_broadcast.py
@@ -7,7 +7,8 @@
import torch.distributed as dist

from vllm.distributed.device_communicators.shm_broadcast import MessageQueue
from vllm.utils import update_environment_variables
from vllm.distributed.utils import StatelessProcessGroup
from vllm.utils import get_ip, get_open_port, update_environment_variables


def get_arrays(n: int, seed: int = 0) -> List[np.ndarray]:
@@ -54,34 +54,61 @@ def wrapped_fn(env):

@worker_fn_wrapper
def worker_fn():
writer_rank = 2
broadcaster = MessageQueue.create_from_process_group(
dist.group.WORLD, 40 * 1024, 2, writer_rank)
if dist.get_rank() == writer_rank:
seed = random.randint(0, 1000)
dist.broadcast_object_list([seed], writer_rank)
else:
recv = [None]
dist.broadcast_object_list(recv, writer_rank)
seed = recv[0] # type: ignore
dist.barrier()
# in case we find a race condition
# print the seed so that we can reproduce the error
print(f"Rank {dist.get_rank()} got seed {seed}")
# test broadcasting with about 400MB of data
N = 10_000
if dist.get_rank() == writer_rank:
arrs = get_arrays(N, seed)
for x in arrs:
broadcaster.broadcast_object(x)
time.sleep(random.random() / 1000)

rank = dist.get_rank()
if rank == 0:
port = get_open_port()
ip = get_ip()
dist.broadcast_object_list([ip, port], src=0)
else:
arrs = get_arrays(N, seed)
for x in arrs:
y = broadcaster.broadcast_object(None)
assert np.array_equal(x, y)
time.sleep(random.random() / 1000)
dist.barrier()
recv = [None, None]
dist.broadcast_object_list(recv, src=0)
ip, port = recv

stateless_pg = StatelessProcessGroup.create(ip, port, rank,
dist.get_world_size())

for pg in [dist.group.WORLD, stateless_pg]:

writer_rank = 2
broadcaster = MessageQueue.create_from_process_group(
pg, 40 * 1024, 2, writer_rank)
if rank == writer_rank:
seed = random.randint(0, 1000)
dist.broadcast_object_list([seed], writer_rank)
else:
recv = [None]
dist.broadcast_object_list(recv, writer_rank)
seed = recv[0] # type: ignore

if pg == dist.group.WORLD:
dist.barrier()
else:
pg.barrier()

# in case we find a race condition
# print the seed so that we can reproduce the error
print(f"Rank {rank} got seed {seed}")
# test broadcasting with about 400MB of data
N = 10_000
if rank == writer_rank:
arrs = get_arrays(N, seed)
for x in arrs:
broadcaster.broadcast_object(x)
time.sleep(random.random() / 1000)
else:
arrs = get_arrays(N, seed)
for x in arrs:
y = broadcaster.broadcast_object(None)
assert np.array_equal(x, y)
time.sleep(random.random() / 1000)

if pg == dist.group.WORLD:
dist.barrier()
print("torch distributed passed the test!")
else:
pg.barrier()
print("StatelessProcessGroup passed the test!")


def test_shm_broadcast():
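
The test has to pick the matching barrier for each group type (dist.barrier() for torch.distributed, pg.barrier() for the stateless group). A tiny dispatch helper in the same spirit (illustrative, not part of vLLM):

from typing import Union

import torch.distributed as dist
from torch.distributed import ProcessGroup

from vllm.distributed.utils import StatelessProcessGroup


def group_barrier(pg: Union[ProcessGroup, StatelessProcessGroup]) -> None:
    if isinstance(pg, ProcessGroup):
        # Regular torch.distributed barrier, scoped to the given group.
        dist.barrier(group=pg)
    else:
        # StatelessProcessGroup provides its own barrier().
        pg.barrier()
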
39 changes: 26 additions & 13 deletions vllm/distributed/device_communicators/shm_broadcast.py
@@ -5,7 +5,7 @@
from contextlib import contextmanager
from dataclasses import dataclass, field
from multiprocessing import shared_memory
from typing import List, Optional, Tuple
from typing import List, Optional, Tuple, Union
from unittest.mock import patch

import torch
@@ -15,6 +15,7 @@
from zmq import SUB, SUBSCRIBE, XPUB, XPUB_VERBOSE, Context # type: ignore

import vllm.envs as envs
from vllm.distributed.utils import StatelessProcessGroup
from vllm.logger import init_logger
from vllm.utils import get_ip, get_open_port, is_valid_ipv6_address

@@ -476,13 +477,19 @@ def broadcast_object(self, obj=None):
return self.dequeue()

@staticmethod
def create_from_process_group(pg: ProcessGroup,
def create_from_process_group(pg: Union[ProcessGroup,
StatelessProcessGroup],
max_chunk_bytes,
max_chunks,
writer_rank=0) -> "MessageQueue":
group_rank = dist.get_rank(pg)
group_world_size = dist.get_world_size(pg)
global_ranks = dist.get_process_group_ranks(pg)
if isinstance(pg, ProcessGroup):
group_rank = dist.get_rank(pg)
group_world_size = dist.get_world_size(pg)
global_ranks = dist.get_process_group_ranks(pg)
else:
group_rank = pg.rank
group_world_size = pg.world_size
global_ranks = list(range(pg.world_size))

from vllm.distributed.parallel_state import in_the_same_node_as
status = in_the_same_node_as(pg, source_rank=writer_rank)
@@ -500,15 +507,21 @@ def create_from_process_group(pg: ProcessGroup,
max_chunks=max_chunks,
)
handle = buffer_io.export_handle()
dist.broadcast_object_list([handle],
src=global_ranks[writer_rank],
group=pg)
if isinstance(pg, ProcessGroup):
dist.broadcast_object_list([handle],
src=global_ranks[writer_rank],
group=pg)
else:
pg.broadcast_obj(handle, writer_rank)
else:
recv = [None]
dist.broadcast_object_list(recv,
src=global_ranks[writer_rank],
group=pg)
handle = recv[0] # type: ignore
if isinstance(pg, ProcessGroup):
recv = [None]
dist.broadcast_object_list(recv,
src=global_ranks[writer_rank],
group=pg)
handle = recv[0] # type: ignore
else:
handle = pg.broadcast_obj(None, writer_rank)
buffer_io = MessageQueue.create_from_handle(handle, group_rank)
buffer_io.wait_until_ready()
return buffer_io
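
The two branches above differ only in the broadcast idiom: dist.broadcast_object_list fills a caller-provided list in place and is addressed by global rank, while StatelessProcessGroup.broadcast_obj takes the group-local source rank and returns the value. A unified helper along these lines (illustrative, not part of vLLM):

from typing import Any, Optional, Union

import torch.distributed as dist
from torch.distributed import ProcessGroup

from vllm.distributed.utils import StatelessProcessGroup


def broadcast_one(pg: Union[ProcessGroup, StatelessProcessGroup],
                  obj: Optional[Any], src: int) -> Any:
    """Broadcast a single object from group-local rank src."""
    if isinstance(pg, ProcessGroup):
        container = [obj]
        # torch.distributed expects the global rank of the source.
        global_src = dist.get_process_group_ranks(pg)[src]
        dist.broadcast_object_list(container, src=global_src, group=pg)
        return container[0]
    # The stateless group returns the broadcast object directly.
    return pg.broadcast_obj(obj, src)
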
64 changes: 43 additions & 21 deletions vllm/distributed/parallel_state.py
@@ -37,6 +37,7 @@

import vllm.distributed.kv_transfer.kv_transfer_agent as kv_transfer
import vllm.envs as envs
from vllm.distributed.utils import StatelessProcessGroup
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op, supports_custom_op
@@ -1191,25 +1192,31 @@ def cleanup_dist_env_and_memory(shutdown_ray: bool = False):
torch.cuda.empty_cache()


def in_the_same_node_as(pg: ProcessGroup, source_rank: int = 0) -> List[bool]:
def in_the_same_node_as(pg: Union[ProcessGroup, StatelessProcessGroup],
source_rank: int = 0) -> List[bool]:
"""
This is a collective operation that returns if each rank is in the same node
as the source rank. It tests if processes are attached to the same
memory system (shared access to shared memory).
"""
assert torch.distributed.get_backend(
pg) != torch.distributed.Backend.NCCL, (
"in_the_same_node_as should be tested with a non-NCCL group.")
# local rank inside the group
rank = torch.distributed.get_rank(group=pg)
world_size = torch.distributed.get_world_size(group=pg)
if isinstance(pg, ProcessGroup):
assert torch.distributed.get_backend(
pg) != torch.distributed.Backend.NCCL, (
"in_the_same_node_as should be tested with a non-NCCL group.")
# local rank inside the group
rank = torch.distributed.get_rank(group=pg)
world_size = torch.distributed.get_world_size(group=pg)

# global ranks of the processes in the group
ranks = torch.distributed.get_process_group_ranks(pg)
else:
rank = pg.rank
world_size = pg.world_size
ranks = list(range(world_size))

# local tensor in each process to store the result
is_in_the_same_node = torch.tensor([0] * world_size, dtype=torch.int32)

# global ranks of the processes in the group
ranks = torch.distributed.get_process_group_ranks(pg)

magic_message = b"magic_message"
shm = None

@@ -1219,17 +1226,21 @@ def in_the_same_node_as(pg: ProcessGroup, source_rank: int = 0) -> List[bool]:
# create a shared memory segment
shm = shared_memory.SharedMemory(create=True, size=128)
shm.buf[:len(magic_message)] = magic_message
torch.distributed.broadcast_object_list([shm.name],
src=ranks[source_rank],
group=pg)
if isinstance(pg, ProcessGroup):
torch.distributed.broadcast_object_list(
[shm.name], src=ranks[source_rank], group=pg)
else:
pg.broadcast_obj(shm.name, src=source_rank)
is_in_the_same_node[rank] = 1
else:
# try to open the shared memory segment
recv = [None]
torch.distributed.broadcast_object_list(recv,
src=ranks[source_rank],
group=pg)
name = recv[0]
if isinstance(pg, ProcessGroup):
recv = [None]
torch.distributed.broadcast_object_list(
recv, src=ranks[source_rank], group=pg)
name = recv[0]
else:
name = pg.broadcast_obj(None, src=source_rank)
# fix to https://stackoverflow.com/q/62748654/9191338
# Python incorrectly tracks shared memory even if it is not
# created by the process. The following patch is a workaround.
@@ -1244,12 +1255,23 @@ def in_the_same_node_as(pg: ProcessGroup, source_rank: int = 0) -> List[bool]:
if shm:
shm.close()

torch.distributed.barrier(group=pg)
if isinstance(pg, ProcessGroup):
torch.distributed.barrier(group=pg)
else:
pg.barrier()

# clean up the shared memory segment
with contextlib.suppress(OSError):
if rank == source_rank and shm:
shm.unlink()
torch.distributed.all_reduce(is_in_the_same_node, group=pg)

return [x == 1 for x in is_in_the_same_node.tolist()]
if isinstance(pg, ProcessGroup):
torch.distributed.all_reduce(is_in_the_same_node, group=pg)
aggregated_data = is_in_the_same_node
else:
aggregated_data = torch.zeros_like(is_in_the_same_node)
for i in range(world_size):
rank_data = pg.broadcast_obj(is_in_the_same_node, src=i)
aggregated_data += rank_data

return [x == 1 for x in aggregated_data.tolist()]
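
StatelessProcessGroup has no all_reduce, so the stateless branch above sums the per-rank vectors by letting every rank broadcast its local result in turn. The same pattern as a standalone helper (illustrative, assuming a tensor-valued payload as in in_the_same_node_as):

import torch

from vllm.distributed.utils import StatelessProcessGroup


def stateless_all_reduce_sum(pg: StatelessProcessGroup,
                             local: torch.Tensor) -> torch.Tensor:
    """Emulate a sum all-reduce with world_size object broadcasts."""
    total = torch.zeros_like(local)
    for src in range(pg.world_size):
        # broadcast_obj returns the source rank's tensor on every rank.
        total += pg.broadcast_obj(local, src=src)
    return total
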
