adapt custom allreduce for tensorrt llm
yizhang2077 committed Dec 18, 2024
1 parent 21e9e63 commit 71b1073
Showing 5 changed files with 222 additions and 111 deletions.
2 changes: 1 addition & 1 deletion python/pyproject.toml
@@ -22,7 +22,7 @@ runtime_common = ["aiohttp", "decord", "fastapi",
     "packaging", "pillow", "prometheus-client>=0.20.0",
     "psutil", "pydantic", "python-multipart",
     "pyzmq>=25.1.2", "torchao>=0.7.0", "uvicorn", "uvloop",
-    "xgrammar>=0.1.6"]
+    "xgrammar>=0.1.6", "sgl-kernel"]
 srt = ["sglang[runtime_common]", "torch", "vllm>=0.6.3.post1,<=0.6.4.post1", "cuda-python", "flashinfer>=0.1.6"]

 # HIP (Heterogeneous-computing Interface for Portability) for AMD
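The new "sgl-kernel" dependency is what the adapted allreduce imports its TensorRT-LLM-style kernels from; the Python side only feature-detects it with a plain import, as the hunks below show. A minimal standalone probe (illustrative only, not part of the commit) would look like:

try:
    import sgl_kernel  # shipped by the new "sgl-kernel" dependency

    custom_ar = True
except ImportError:
    custom_ar = False

print("sgl-kernel custom allreduce available:", custom_ar)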
47 changes: 12 additions & 35 deletions python/sglang/srt/_custom_ops.py
@@ -1,4 +1,4 @@
-# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/_custom_ops.py
+# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/_custom_ops.py
 import contextlib
 import functools
 import importlib
@@ -14,7 +14,7 @@

 if not is_hpu():
     try:
-        import custom_ar
+        import sgl_kernel
     except ImportError as e:
         logger.warning("Failed to import from custom_ar with %r", e)

@@ -50,46 +50,23 @@ def wrapper(*args, **kwargs):

 # custom ar
 def init_custom_ar(
-    ipc_tensors: List[torch.Tensor],
-    rank_data: torch.Tensor,
-    rank: int,
-    full_nvlink: bool,
+    rank_id: int,
+    world_size: int,
+    buffers: List[int],
+    barrier_in: List[int],
+    barrier_out: List[int],
 ) -> int:
-    return torch.ops._C_vllm_ar.init_custom_ar(
-        ipc_tensors, rank_data, rank, full_nvlink
+    return sgl_kernel.ops.init_custom_reduce(
+        rank_id, world_size, buffers, barrier_in, barrier_out
     )


-def all_reduce(
-    fa: int,
-    inp: torch.Tensor,
-    out: torch.Tensor,
-    reg_buffer: int,
-    reg_buffer_sz_bytes: int,
-) -> None:
-    torch.ops._C_vllm_ar.all_reduce(fa, inp, out, reg_buffer, reg_buffer_sz_bytes)
+def all_reduce(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
+    sgl_kernel.ops.custom_reduce(fa, inp, out)


 def dispose(fa: int) -> None:
-    torch.ops._C_vllm_ar.dispose(fa)
-
-
-def meta_size() -> int:
-    return torch.ops._C_vllm_ar.meta_size()
-
-
-def register_buffer(fa: int, ipc_tensors: List[int]) -> None:
-    return torch.ops._C_vllm_ar.register_buffer(fa, ipc_tensors)
-
-
-def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[int], List[int]]:
-    return torch.ops._C_vllm_ar.get_graph_buffer_ipc_meta(fa)
-
-
-def register_graph_buffers(
-    fa: int, handles: List[List[int]], offsets: List[List[int]]
-) -> None:
-    torch.ops._C_vllm_ar.register_graph_buffers(fa, handles, offsets)
+    sgl_kernel.ops.custom_dispose(fa)


 # temporary fix for https://github.com/vllm-project/vllm/issues/5456
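Taken together, the new wrappers shrink the old vLLM-style surface (meta_size, register_buffer, get_graph_buffer_ipc_meta, register_graph_buffers) down to three calls backed by sgl_kernel: init_custom_reduce, custom_reduce, and custom_dispose. A hedged sketch of the intended call order, assuming the per-rank IPC pointer lists have already been exchanged the way CustomAllreduce.create_shared_buffer does in the hunks that follow (the function and variable names here are ours, not part of the commit):

from typing import List

import torch

from sglang.srt import _custom_ops as ops


def allreduce_once(
    rank: int,
    world_size: int,
    buffers: List[int],
    barrier_in: List[int],
    barrier_out: List[int],
    inp: torch.Tensor,
) -> torch.Tensor:
    # One-shot illustration: initialize, reduce out-of-place, then dispose.
    fa = ops.init_custom_ar(rank, world_size, buffers, barrier_in, barrier_out)
    try:
        out = torch.empty_like(inp)
        ops.all_reduce(fa, inp, out)  # sgl_kernel.ops.custom_reduce under the hood
        return out
    finally:
        ops.dispose(fa)  # sgl_kernel.ops.custom_dispose under the hood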
@@ -21,15 +21,15 @@
 from sglang.srt.utils import cuda_device_count_stateless, is_cuda

 try:
-    ops.meta_size()
+    import sgl_kernel

     custom_ar = True
 except Exception:
     # For AMD GPUs and CPUs
     custom_ar = False

 logger = logging.getLogger(__name__)


 _P = ParamSpec("_P")
 _R = TypeVar("_R")

@@ -47,7 +47,7 @@ def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:


 @with_nvml_context
-def is_full_nvlink(cls, physical_device_ids: List[int]) -> bool:
+def is_full_nvlink(physical_device_ids: List[int]) -> bool:
     """
     query if the set of gpus are fully connected by nvlink (1 hop)
     """
@@ -196,32 +196,33 @@ def __init__(
             )
             return

-        self.disabled = False
-        # Buffers memory are owned by this Python class and passed to C++.
-        # Meta data composes of two parts: meta data for synchronization and a
-        # temporary buffer for storing intermediate allreduce results.
-        self.meta_ptrs = self.create_shared_buffer(
-            ops.meta_size() + max_size, group=group
-        )
-        # This is a pre-registered IPC buffer. In eager mode, input tensors
-        # are first copied into this buffer before allreduce is performed
-        self.buffer_ptrs = self.create_shared_buffer(max_size, group=group)
-        # This is a buffer for storing the tuples of pointers pointing to
-        # IPC buffers from all ranks. Each registered tuple has size of
-        # 8*world_size bytes where world_size is at most 8. Allocating 8MB
-        # is enough for 131072 such tuples. The largest model I've seen only
-        # needs less than 10000 of registered tuples.
-        self.rank_data = torch.empty(
-            8 * 1024 * 1024, dtype=torch.uint8, device=self.device
-        )
         self.max_size = max_size
         self.rank = rank
         self.world_size = world_size
         self.full_nvlink = full_nvlink
+
+        # From TensorRT-LLM getMaxRequiredWorkspaceSize
+        self.max_required_workspace_size = [16 * 1000 * 1000, 8 * 1000 * 1000]
+
+        # sizeof(uint32_t) * (MAX_ALL_REDUCE_BLOCKS + 2) * MAX_RANKS_PER_NODE;
+        self.barrier_max_size = 8 * (24 + 2) * 8
+
+        self.buffer_ptrs = self.create_shared_buffer(max_size, group=group)
+        self.barrier_in_ptrs = self.create_shared_buffer(
+            self.barrier_max_size, group=group
+        )
+        self.barrier_out_ptrs = self.create_shared_buffer(
+            self.barrier_max_size, group=group
+        )
+
         self._ptr = ops.init_custom_ar(
-            self.meta_ptrs, self.rank_data, rank, self.full_nvlink
+            rank,
+            world_size,
+            self.buffer_ptrs,
+            self.barrier_in_ptrs,
+            self.barrier_out_ptrs,
         )
-        ops.register_buffer(self._ptr, self.buffer_ptrs)
+        self.disabled = False

     @staticmethod
     def create_shared_buffer(
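As a quick check on the numbers this constructor now sets up (illustrative; the 8 MiB max_size default is our assumption, the other constants are copied from the hunk above):

max_size = 8 * 1024 * 1024  # assumed default size of the shared IPC data buffer, in bytes
max_required_workspace_size = [16 * 1000 * 1000, 8 * 1000 * 1000]  # 2-rank cap, full-NVLink cap
barrier_max_size = 8 * (24 + 2) * 8  # 1664 bytes per barrier buffer

# Shared memory created per rank: one data buffer plus barrier_in and barrier_out.
per_rank_bytes = max_size + 2 * barrier_max_size
print(per_rank_bytes)  # 8391936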
@@ -258,36 +259,11 @@ def free_shared_buffer(

     @contextmanager
     def capture(self):
-        """
-        The main responsibility of this context manager is the
-        `register_graph_buffers` call at the end of the context.
-        It records all the buffer addresses used in the CUDA graph.
-        """
         try:
             self._IS_CAPTURING = True
             yield
         finally:
             self._IS_CAPTURING = False
-            if not self.disabled:
-                self.register_graph_buffers()
-
-    def register_graph_buffers(self):
-        handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr)
-        logger.info("Registering %d cuda graph addresses", len(offset))
-        # We cannot directly use `dist.all_gather_object` here
-        # because it is incompatible with `gloo` backend under inference mode.
-        # see https://github.com/pytorch/pytorch/issues/126032 for details.
-        all_data = [[None, None] for _ in range(dist.get_world_size(group=self.group))]
-        all_data[self.rank] = [handle, offset]
-        ranks = sorted(dist.get_process_group_ranks(group=self.group))
-        for i, rank in enumerate(ranks):
-            dist.broadcast_object_list(
-                all_data[i], src=rank, group=self.group, device="cpu"
-            )
-        # Unpack list of tuples to tuple of lists.
-        handles = [d[0] for d in all_data]  # type: ignore
-        offsets = [d[1] for d in all_data]  # type: ignore
-        ops.register_graph_buffers(self._ptr, handles, offsets)

     def should_custom_ar(self, inp: torch.Tensor):
         if self.disabled:
@@ -300,28 +276,19 @@ def should_custom_ar(self, inp: torch.Tensor):
             return False
         # for 4 or more non NVLink-capable GPUs, custom allreduce provides
         # little performance improvement over NCCL.
-        if self.world_size == 2 or self.full_nvlink:
-            return inp_size < self.max_size
-        return False
-
-    def all_reduce(
-        self, inp: torch.Tensor, *, out: torch.Tensor = None, registered: bool = False
-    ):
-        """Performs an out-of-place all reduce.
+        if self.world_size == 2:
+            return (
+                inp_size < self.max_size
+                and inp_size < self.max_required_workspace_size[0]
+            )

-        If registered is True, this assumes inp's pointer is already
-        IPC-registered. Otherwise, inp is first copied into a pre-registered
-        buffer.
-        """
-        if out is None:
-            out = torch.empty_like(inp)
-        if registered:
-            ops.all_reduce(self._ptr, inp, out, 0, 0)
-        else:
-            ops.all_reduce(
-                self._ptr, inp, out, self.buffer_ptrs[self.rank], self.max_size
+        if self.full_nvlink:
+            return (
+                inp_size < self.max_size
+                and inp_size < self.max_required_workspace_size[1]
             )
-        return out
+
+        return False

     def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]:
         """The main allreduce API that provides support for cuda graph."""
@@ -330,23 +297,25 @@ def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]:
             return None
         if self._IS_CAPTURING:
             if torch.cuda.is_current_stream_capturing():
-                return self.all_reduce(input, registered=True)
+                output = torch.empty_like(input)
+                ops.all_reduce(self._ptr, input, output)
+                return output
             else:
                 # If warm up, mimic the allocation pattern since custom
                 # allreduce is out-of-place.
                 return torch.empty_like(input)
         else:
             # Note: outside of cuda graph context, custom allreduce incurs a
             # cost of cudaMemcpy, which should be small (<=1% of overall
             # latency) compared to the performance gain of using custom kernels
-            return self.all_reduce(input, registered=False)
+            output = torch.empty_like(input)
+            ops.all_reduce(self._ptr, input, output)
+            return output

     def close(self):
         if not self.disabled and self._ptr:
             ops.dispose(self._ptr)
-            self._ptr = 0
-            self.free_shared_buffer(self.meta_ptrs)
             self.free_shared_buffer(self.buffer_ptrs)
+            self.free_shared_buffer(self.barrier_in_ptrs)
+            self.free_shared_buffer(self.barrier_out_ptrs)
+            self._ptr = 0

     def __del__(self):
         self.close()
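Because custom_all_reduce() returns None whenever the message does not qualify or the communicator is disabled, a caller can keep the regular NCCL path as the fallback. A hedged sketch of that pattern, with a hypothetical comm variable holding a CustomAllreduce instance (this mirrors how a distributed layer would typically consume it; it is not code from this commit):

from typing import Optional

import torch
import torch.distributed as dist


def allreduce_with_fallback(comm, tensor: torch.Tensor, group=None) -> torch.Tensor:
    # Try the custom TensorRT-LLM-style kernel first; it is out-of-place.
    out: Optional[torch.Tensor] = comm.custom_all_reduce(tensor)
    if out is None:
        # Message too large, wrong size, or custom allreduce disabled:
        # fall back to the regular torch.distributed (NCCL) allreduce, in place.
        dist.all_reduce(tensor, group=group)
        out = tensor
    return out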
1 change: 1 addition & 0 deletions test/srt/run_suite.py
@@ -12,6 +12,7 @@
         "sampling/penaltylib",
         "test_abort.py",
         "test_chunked_prefill.py",
+        "test_custom_allreduce.py",
         "test_double_sparsity.py",
         "test_embedding_openai_server.py",
         "test_eval_accuracy_mini.py",
