adapt custom allreduce for tensorrt llm #2511

Merged: 5 commits, Jan 15, 2025

2 changes: 1 addition & 1 deletion python/pyproject.toml
@@ -27,7 +27,7 @@ runtime_common = [
]
srt = [
"sglang[runtime_common]", "cuda-python",
"sgl-kernel>=0.0.2.post12", "torch", "vllm>=0.6.3.post1,<=0.6.4.post1",
"sgl-kernel>=0.0.2.post14", "torch", "vllm>=0.6.3.post1,<=0.6.4.post1",
"flashinfer==0.1.6"
]
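
The packaging change raises the sgl-kernel floor from 0.0.2.post12 to 0.0.2.post14 so that the TensorRT-LLM-style custom all-reduce ops used in the files below are available. A quick sanity check of an installed environment (illustrative, not part of the PR; assumes the packaging library is installed):

    import importlib.metadata

    from packaging.version import Version

    installed = Version(importlib.metadata.version("sgl-kernel"))
    assert installed >= Version("0.0.2.post14"), f"sgl-kernel {installed} is too old"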

49 changes: 22 additions & 27 deletions python/sglang/srt/_custom_ops.py
@@ -1,4 +1,4 @@
# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/_custom_ops.py
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/_custom_ops.py
import contextlib
import functools
import importlib
@@ -14,7 +14,7 @@

if not is_hpu():
try:
import custom_ar
import sgl_kernel
except ImportError as e:
logger.warning("Failed to import from custom_ar with %r", e)

@@ -50,46 +50,41 @@ def wrapper(*args, **kwargs):

# custom ar
def init_custom_ar(
ipc_tensors: List[torch.Tensor],
rank_data: torch.Tensor,
rank: int,
full_nvlink: bool,
rank_id: int,
world_size: int,
rank_data_base: torch.Tensor,
buffers: List[int],
tmp_result_buffers: List[int],
barrier_in: List[int],
barrier_out: List[int],
) -> int:
return torch.ops._C_vllm_ar.init_custom_ar(
ipc_tensors, rank_data, rank, full_nvlink
return sgl_kernel.ops.init_custom_reduce(
rank_id,
world_size,
rank_data_base,
buffers,
tmp_result_buffers,
barrier_in,
barrier_out,
)


def all_reduce(
fa: int,
inp: torch.Tensor,
out: torch.Tensor,
reg_buffer: int,
reg_buffer_sz_bytes: int,
) -> None:
torch.ops._C_vllm_ar.all_reduce(fa, inp, out, reg_buffer, reg_buffer_sz_bytes)
def all_reduce(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
sgl_kernel.ops.custom_reduce(fa, inp, out)


def dispose(fa: int) -> None:
torch.ops._C_vllm_ar.dispose(fa)


def meta_size() -> int:
return torch.ops._C_vllm_ar.meta_size()


def register_buffer(fa: int, ipc_tensors: List[int]) -> None:
return torch.ops._C_vllm_ar.register_buffer(fa, ipc_tensors)
sgl_kernel.ops.custom_dispose(fa)


def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[int], List[int]]:
return torch.ops._C_vllm_ar.get_graph_buffer_ipc_meta(fa)
return sgl_kernel.ops.get_graph_buffer_ipc_meta(fa)


def register_graph_buffers(
fa: int, handles: List[List[int]], offsets: List[List[int]]
) -> None:
torch.ops._C_vllm_ar.register_graph_buffers(fa, handles, offsets)
sgl_kernel.ops.register_graph_buffers(fa, handles, offsets)


# temporary fix for https://github.com/vllm-project/vllm/issues/5456
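
Taken together, _custom_ops.py drops the vLLM torch.ops._C_vllm_ar bindings in favor of sgl_kernel: init_custom_ar now forwards to sgl_kernel.ops.init_custom_reduce and takes the rank, world size, rank-data tensor, and per-rank IPC buffer/barrier pointer lists; all_reduce forwards to sgl_kernel.ops.custom_reduce and no longer needs a registration buffer or size; meta_size and register_buffer are removed; the graph-buffer helpers keep their signatures. A rough sketch of how one process per GPU might drive the new wrappers (illustrative only; allocating and exchanging the IPC pointers is elided here and is what the CustomAllreduce communicator in the next diff does for real):

    import torch

    from sglang.srt import _custom_ops as ops

    def run_rank(rank, world_size, rank_data_base, buffers,
                 tmp_result_buffers, barrier_in, barrier_out):
        # Each process is assumed to be bound to its own GPU, and every pointer
        # list must already contain IPC-opened addresses from all ranks.
        fa = ops.init_custom_ar(rank, world_size, rank_data_base, buffers,
                                tmp_result_buffers, barrier_in, barrier_out)
        inp = torch.ones(4096, dtype=torch.float16, device="cuda")
        out = torch.empty_like(inp)
        ops.all_reduce(fa, inp, out)  # out should now hold inp summed over all ranks
        ops.dispose(fa)
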
@@ -21,15 +21,15 @@
from sglang.srt.utils import cuda_device_count_stateless, is_cuda

try:
ops.meta_size()
import sgl_kernel

custom_ar = True
except Exception:
# For AMD GPUs and CPUs
custom_ar = False

logger = logging.getLogger(__name__)


_P = ParamSpec("_P")
_R = TypeVar("_R")

@@ -47,7 +47,7 @@ def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:


@with_nvml_context
def is_full_nvlink(cls, physical_device_ids: List[int]) -> bool:
def is_full_nvlink(physical_device_ids: List[int]) -> bool:
"""
query if the set of gpus are fully connected by nvlink (1 hop)
"""
@@ -196,32 +196,39 @@ def __init__(
)
return

self.disabled = False
# Buffers memory are owned by this Python class and passed to C++.
# Meta data composes of two parts: meta data for synchronization and a
# temporary buffer for storing intermediate allreduce results.
self.meta_ptrs = self.create_shared_buffer(
ops.meta_size() + max_size, group=group
)
# This is a pre-registered IPC buffer. In eager mode, input tensors
# are first copied into this buffer before allreduce is performed
self.buffer_ptrs = self.create_shared_buffer(max_size, group=group)
# This is a buffer for storing the tuples of pointers pointing to
# IPC buffers from all ranks. Each registered tuple has size of
# 8*world_size bytes where world_size is at most 8. Allocating 8MB
# is enough for 131072 such tuples. The largest model I've seen only
# needs less than 10000 of registered tuples.
self.rank_data = torch.empty(
8 * 1024 * 1024, dtype=torch.uint8, device=self.device
)
self.max_size = max_size
self.rank = rank
self.world_size = world_size
self.full_nvlink = full_nvlink

# From TensorRT-LLM getMaxRequiredWorkspaceSize
self.max_required_workspace_size = [16 * 1024 * 1024, 8 * 1024 * 1024]

# sizeof(uint32_t) * (MAX_ALL_REDUCE_BLOCKS + 2) * MAX_RANKS_PER_NODE;
self.barrier_max_size = 8 * (36 + 2) * 8

self.buffer_ptrs = self.create_shared_buffer(max_size, group=group)
self.tmp_result_buffer_ptrs = self.create_shared_buffer(max_size, group=group)
self.rank_data_base = torch.empty(
8 * 1024 * 1024, dtype=torch.uint8, device=self.device
)
self.barrier_in_ptrs = self.create_shared_buffer(
self.barrier_max_size, group=group
)
self.barrier_out_ptrs = self.create_shared_buffer(
self.barrier_max_size, group=group
)

self._ptr = ops.init_custom_ar(
self.meta_ptrs, self.rank_data, rank, self.full_nvlink
rank,
world_size,
self.rank_data_base,
self.buffer_ptrs,
self.tmp_result_buffer_ptrs,
self.barrier_in_ptrs,
self.barrier_out_ptrs,
)
ops.register_buffer(self._ptr, self.buffer_ptrs)
self.disabled = False
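
The constructor now sets up the TensorRT-LLM-style workspace (data buffers, temporary result buffers, rank data, and in/out barriers) instead of vLLM's meta buffer plus registered IPC buffer, and the eager-mode buffer no longer has to be registered. A worked example of the per-rank sizes implied by the code above (illustrative, not additional PR code):

    # Per-rank IPC allocations made in __init__:
    #   buffer_ptrs            -> max_size bytes     (shared data buffers)
    #   tmp_result_buffer_ptrs -> max_size bytes     (intermediate all-reduce results)
    #   rank_data_base         -> 8 * 1024 * 1024 B  (pointers for registered graph buffers)
    #   barrier_in/out_ptrs    -> 8 * (36 + 2) * 8 B (synchronization flags)
    barrier_max_size = 8 * (36 + 2) * 8
    assert barrier_max_size == 2432  # bytes per barrier buffer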

@staticmethod
def create_shared_buffer(
@@ -300,12 +307,25 @@ def should_custom_ar(self, inp: torch.Tensor):
return False
# for 4 or more non NVLink-capable GPUs, custom allreduce provides
# little performance improvement over NCCL.
if self.world_size == 2 or self.full_nvlink:
return inp_size < self.max_size
if self.world_size == 2:
return (
inp_size < self.max_size
and inp_size < self.max_required_workspace_size[0]
)

if self.full_nvlink:
return (
inp_size < self.max_size
and inp_size < self.max_required_workspace_size[1]
)

return False
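
Relative to the old rule (anything under max_size when world_size is 2 or the GPUs are fully NVLinked), eligibility is now additionally capped by TensorRT-LLM's maximum required workspace: 16 MiB of input for the two-GPU case and 8 MiB for the fully NVLinked case, on top of the earlier checks in the collapsed part of this method. A condensed restatement (a sketch, not PR code):

    MiB = 1024 * 1024

    def custom_ar_eligible(inp_bytes, world_size, full_nvlink, max_size):
        # Mirrors the tail of should_custom_ar above; the other preconditions
        # (communicator enabled, size alignment, etc.) are in the collapsed hunk.
        if world_size == 2:
            return inp_bytes < max_size and inp_bytes < 16 * MiB
        if full_nvlink:
            return inp_bytes < max_size and inp_bytes < 8 * MiB
        return False

    assert custom_ar_eligible(4 * MiB, 8, True, 8 * MiB)        # small NVLink tensor: use custom AR
    assert not custom_ar_eligible(12 * MiB, 8, True, 16 * MiB)  # exceeds the 8 MiB workspace cap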

def all_reduce(
self, inp: torch.Tensor, *, out: torch.Tensor = None, registered: bool = False
self,
inp: torch.Tensor,
*,
out: torch.Tensor = None,
):
"""Performs an out-of-place all reduce.

@@ -315,12 +335,7 @@ def all_reduce(
"""
if out is None:
out = torch.empty_like(inp)
if registered:
ops.all_reduce(self._ptr, inp, out, 0, 0)
else:
ops.all_reduce(
self._ptr, inp, out, self.buffer_ptrs[self.rank], self.max_size
)
ops.all_reduce(self._ptr, inp, out)
return out

def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]:
@@ -330,23 +345,22 @@ def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]:
return None
if self._IS_CAPTURING:
if torch.cuda.is_current_stream_capturing():
return self.all_reduce(input, registered=True)
return self.all_reduce(input)
else:
# If warm up, mimic the allocation pattern since custom
# allreduce is out-of-place.
return torch.empty_like(input)
else:
# Note: outside of cuda graph context, custom allreduce incurs a
# cost of cudaMemcpy, which should be small (<=1% of overall
# latency) compared to the performance gain of using custom kernels
return self.all_reduce(input, registered=False)
return self.all_reduce(input)

def close(self):
if not self.disabled and self._ptr:
ops.dispose(self._ptr)
self._ptr = 0
self.free_shared_buffer(self.meta_ptrs)
self.free_shared_buffer(self.buffer_ptrs)
self.free_shared_buffer(self.tmp_result_buffer_ptrs)
self.free_shared_buffer(self.barrier_in_ptrs)
self.free_shared_buffer(self.barrier_out_ptrs)
self._ptr = 0

def __del__(self):
self.close()
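
With the registration path gone, the CUDA-graph-capture branch and the eager branch of custom_all_reduce now call the same out-of-place all_reduce(inp, out=...), and graph warmup still only mimics the output allocation. A hedged sketch of a typical call site (ca is an already-initialized CustomAllreduce; the NCCL fallback reflects the usual pattern in the surrounding distributed code, not code from this PR):

    import torch
    import torch.distributed as dist

    def graph_safe_all_reduce(ca, x: torch.Tensor) -> torch.Tensor:
        out = ca.custom_all_reduce(x)  # returns None when the input is not eligible
        if out is None:
            dist.all_reduce(x)         # fall back to NCCL, in place
            out = x
        return out
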
1 change: 1 addition & 0 deletions test/srt/run_suite.py
@@ -12,6 +12,7 @@
"sampling/penaltylib",
"test_abort.py",
"test_chunked_prefill.py",
"test_custom_allreduce.py",
"test_double_sparsity.py",
"test_eagle_infer.py",
"test_embedding_openai_server.py",