From 9c6130a6fb999c96d18567d48d4dd723fe36c547 Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Fri, 3 May 2024 10:57:55 -0700
Subject: [PATCH 01/43] add cache for loading the same library multiple times

---
 .../device_communicators/pynccl_wrapper.py    | 253 ++++++++++++++++++
 1 file changed, 253 insertions(+)
 create mode 100644 vllm/distributed/device_communicators/pynccl_wrapper.py

diff --git a/vllm/distributed/device_communicators/pynccl_wrapper.py b/vllm/distributed/device_communicators/pynccl_wrapper.py
new file mode 100644
index 0000000000000..c631e223877c8
--- /dev/null
+++ b/vllm/distributed/device_communicators/pynccl_wrapper.py
@@ -0,0 +1,253 @@
+# This file is a pure Python wrapper for the NCCL library.
+# The main purpose is to use NCCL combined with CUDA graph.
+# Before writing this script, we tried the following approaches:
+# 1. We tried to use `cupy`; it calls NCCL correctly, but `cupy` itself
+# often gets stuck when initializing the NCCL communicator.
+# 2. We tried to use `torch.distributed`, but `torch.distributed.all_reduce`
+# contains many other potential CUDA APIs that are not allowed while
+# capturing the CUDA graph. For further details, please check
+# https://discuss.pytorch.org/t/pytorch-cudagraph-with-nccl-operation-failed/ .
+#
+# Another rejected idea is to write a C/C++ binding for NCCL. It is usually
+# doable, but we often encounter issues related to nccl versions, and need
+# to switch between different versions of NCCL. See
+# https://github.com/NVIDIA/nccl/issues/1234 for more details.
+# A C/C++ binding is not flexible enough to handle this. It requires
+# recompilation of the code every time we want to switch between different
+# versions. This current implementation, with a **pure** Python wrapper, is
+# more flexible. We can easily switch between different versions of NCCL by
+# changing the environment variable `VLLM_NCCL_SO_PATH`, or the `so_file`
+# variable in the code.
+ +import ctypes +import platform +from dataclasses import dataclass +from typing import Any, Dict, List, Optional + +import torch +from torch.distributed import ReduceOp + +from vllm.logger import init_logger +from vllm.utils import find_nccl_library, nccl_integrity_check + +logger = init_logger(__name__) + +# === export types and functions from nccl to Python === +# for the original nccl definition, please check +# https://github.com/NVIDIA/nccl/blob/master/src/nccl.h.in + +ncclResult_t = ctypes.c_int +ncclComm_t = ctypes.c_void_p + + +class ncclUniqueId(ctypes.Structure): + _fields_ = [("internal", ctypes.c_byte * 128)] + + +cudaStream_t = ctypes.c_void_p +buffer_type = ctypes.c_void_p + +ncclDataType_t = ctypes.c_int + + +class ncclDataTypeEnum: + ncclInt8 = 0 + ncclChar = 0 + ncclUint8 = 1 + ncclInt32 = 2 + ncclInt = 2 + ncclUint32 = 3 + ncclInt64 = 4 + ncclUint64 = 5 + ncclFloat16 = 6 + ncclHalf = 6 + ncclFloat32 = 7 + ncclFloat = 7 + ncclFloat64 = 8 + ncclDouble = 8 + ncclBfloat16 = 9 + ncclNumTypes = 10 + + @classmethod + def from_torch(cls, dtype: torch.dtype) -> int: + if dtype == torch.int8: + return cls.ncclInt8 + if dtype == torch.uint8: + return cls.ncclUint8 + if dtype == torch.int32: + return cls.ncclInt32 + if dtype == torch.int64: + return cls.ncclInt64 + if dtype == torch.float16: + return cls.ncclFloat16 + if dtype == torch.float32: + return cls.ncclFloat32 + if dtype == torch.float64: + return cls.ncclFloat64 + if dtype == torch.bfloat16: + return cls.ncclBfloat16 + raise ValueError(f"Unsupported dtype: {dtype}") + + +ncclRedOp_t = ctypes.c_int + + +class ncclRedOpTypeEnum: + ncclSum = 0 + ncclProd = 1 + ncclMax = 2 + ncclMin = 3 + ncclAvg = 4 + ncclNumOps = 5 + + @classmethod + def from_torch(cls, op: ReduceOp) -> int: + if op == ReduceOp.SUM: + return cls.ncclSum + if op == ReduceOp.PRODUCT: + return cls.ncclProd + if op == ReduceOp.MAX: + return cls.ncclMax + if op == ReduceOp.MIN: + return cls.ncclMin + if op == ReduceOp.AVG: + return cls.ncclAvg + raise ValueError(f"Unsupported op: {op}") + + +@dataclass +class Function: + name: str + restype: Any + argtypes: List[Any] + + +class NCCLLibrary: + exported_functions = [ + # const char* ncclGetErrorString(ncclResult_t result) + Function("ncclGetErrorString", ctypes.c_char_p, [ncclResult_t]), + # ncclResult_t ncclGetVersion(int *version); + Function("ncclGetVersion", ncclResult_t, + [ctypes.POINTER(ctypes.c_int)]), + # ncclResult_t ncclGetUniqueId(ncclUniqueId* uniqueId); + Function("ncclGetUniqueId", ncclResult_t, + [ctypes.POINTER(ncclUniqueId)]), + # ncclResult_t ncclCommInitRank( + # ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank); + # note that ncclComm_t is a pointer type, so the first argument + # is a pointer to a pointer + Function("ncclCommInitRank", ncclResult_t, [ + ctypes.POINTER(ncclComm_t), ctypes.c_int, ncclUniqueId, + ctypes.c_int + ]), + # ncclResult_t ncclAllReduce( + # const void* sendbuff, void* recvbuff, size_t count, + # ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm, + # cudaStream_t stream); + # note that cudaStream_t is a pointer type, so the last argument + # is a pointer + Function("ncclAllReduce", ncclResult_t, [ + buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t, + ncclRedOp_t, ncclComm_t, cudaStream_t + ]), + # ncclResult_t ncclCommDestroy(ncclComm_t comm); + Function("ncclCommDestroy", ncclResult_t, [ncclComm_t]), + ] + + # class attribute to store the mapping from the path to the library + # to avoid loading the same library multiple times + 
path_to_library_cache: Dict[str, Any] = {} + + # class attribute to store the mapping from library path + # to the corresponding dictionary + path_to_dict_mapping: Dict[str, Dict[str, Any]] = {} + + def __init__(self, so_file: Optional[str] = None): + + so_file = so_file or find_nccl_library() + + try: + # load the library in another process. + # if it core dumps, it will not crash the current process + nccl_integrity_check(so_file) + except Exception as e: + logger.error( + "Failed to load NCCL library from %s ." + "It is expected if you are not running on NVIDIA/AMD GPUs." + "Otherwise, the nccl library might not exist, be corrupted " + "or it does not support the current platform %s." + "One solution is to download libnccl2 version 2.18 from " + "https://developer.download.nvidia.com/compute/cuda/repos/ " + "and extract the libnccl.so.2 file. If you already have the " + "library, please set the environment variable VLLM_NCCL_SO_PATH" + " to point to the correct nccl library path.", so_file, + platform.platform()) + raise e + + if so_file not in NCCLLibrary.path_to_dict_mapping: + lib = ctypes.CDLL(so_file) + NCCLLibrary.path_to_library_cache[so_file] = lib + self.lib = NCCLLibrary.path_to_library_cache[so_file] + + if so_file not in NCCLLibrary.path_to_dict_mapping: + _funcs = {} + for func in NCCLLibrary.exported_functions: + f = getattr(self.lib, func.name) + f.restype = func.restype + f.argtypes = func.argtypes + _funcs[func.name] = f + NCCLLibrary.path_to_dict_mapping[so_file] = _funcs + self._funcs = NCCLLibrary.path_to_dict_mapping[so_file] + + def ncclGetErrorString(self, result: ncclResult_t) -> str: + return self._funcs["ncclGetErrorString"](result).decode("utf-8") + + def NCCL_CHECK(self, result: ncclResult_t) -> None: + if result != 0: + error_str = self.ncclGetErrorString(result) + raise RuntimeError(f"NCCL error: {error_str}") + + def ncclGetVersion(self) -> str: + version = ctypes.c_int() + self.NCCL_CHECK(self._funcs["ncclGetVersion"](ctypes.byref(version))) + version_str = str(version.value) + # something like 21903 --> "2.19.3" + major = version_str[0].lstrip("0") + minor = version_str[1:3].lstrip("0") + patch = version_str[3:].lstrip("0") + return f"{major}.{minor}.{patch}" + + def ncclGetUniqueId(self) -> ncclUniqueId: + unique_id = ncclUniqueId() + self.NCCL_CHECK(self._funcs["ncclGetUniqueId"]( + ctypes.byref(unique_id))) + return unique_id + + def ncclCommInitRank(self, world_size: int, unique_id: ncclUniqueId, + rank: int) -> ncclComm_t: + comm = ncclComm_t() + self.NCCL_CHECK(self._funcs["ncclCommInitRank"](ctypes.byref(comm), + world_size, unique_id, + rank)) + return comm + + def ncclAllReduce(self, sendbuff: buffer_type, recvbuff: buffer_type, + count: int, datatype: int, op: int, comm: ncclComm_t, + stream: cudaStream_t) -> None: + # `datatype` actually should be `ncclDataType_t` + # and `op` should be `ncclRedOp_t` + # both are aliases of `ctypes.c_int` + # when we pass int to a function, it will be converted to `ctypes.c_int` + # by ctypes automatically + self.NCCL_CHECK(self._funcs["ncclAllReduce"](sendbuff, recvbuff, count, + datatype, op, comm, + stream)) + + def ncclCommDestroy(self, comm: ncclComm_t) -> None: + self.NCCL_CHECK(self._funcs["ncclCommDestroy"](comm)) + + +__all__ = [ + "NCCLLibrary", "ncclDataTypeEnum", "ncclRedOpTypeEnum", "ncclUniqueId", + "ncclComm_t", "cudaStream_t", "buffer_type" +] From 149324321ef20e0a3f79347815e5d92f5df3eb5a Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 3 May 2024 13:24:24 -0700 Subject: [PATCH 02/43] 
refactor code --- .../device_communicators/pynccl.py | 238 ++---------------- .../device_communicators/pynccl_utils.py | 5 +- .../device_communicators/pynccl_wrapper.py | 5 + 3 files changed, 32 insertions(+), 216 deletions(-) diff --git a/vllm/distributed/device_communicators/pynccl.py b/vllm/distributed/device_communicators/pynccl.py index 758994352e3de..52c86badf57bc 100644 --- a/vllm/distributed/device_communicators/pynccl.py +++ b/vllm/distributed/device_communicators/pynccl.py @@ -1,26 +1,3 @@ -# This file is a pure Python wrapper for the NCCL library. -# The main purpose is to use NCCL combined with CUDA graph. -# Before writing this script, we tried the following approach: -# 1. We tried to use `cupy`, it calls NCCL correctly, but `cupy` itself -# often gets stuck when initializing the NCCL communicator. -# 2. We tried to use `torch.distributed`, but `torch.distributed.all_reduce` -# contains many other potential cuda APIs, that are not allowed during -# capturing the CUDA graph. For further details, please check -# https://discuss.pytorch.org/t/pytorch-cudagraph-with-nccl-operation-failed/ . -# -# Another rejected idea is to write a C/C++ binding for NCCL. It is usually -# doable, but we often encounter issues related with nccl versions, and need -# to switch between different versions of NCCL. See -# https://github.com/NVIDIA/nccl/issues/1234 for more details. -# A C/C++ binding is not flexible enough to handle this. It requires -# recompilation of the code every time we want to switch between different -# versions. This current implementation, with a **pure** Python wrapper, is -# more flexible. We can easily switch between different versions of NCCL by -# changing the environment variable `VLLM_NCCL_SO_PATH`, or the `so_file` -# variable in the code. - -import ctypes -import platform from typing import Optional, Union # ===================== import region ===================== @@ -28,188 +5,14 @@ import torch.distributed as dist from torch.distributed import ProcessGroup, ReduceOp +from vllm.distributed.device_communicators.pynccl_wrapper import ( + NCCLLibrary, buffer_type, cudaStream_t, ncclComm_t, ncclDataTypeEnum, + ncclRedOpTypeEnum, ncclUniqueId) from vllm.distributed.parallel_state import get_cpu_world_group, get_local_rank from vllm.logger import init_logger -from vllm.utils import find_nccl_library, nccl_integrity_check logger = init_logger(__name__) -so_file = find_nccl_library() - -try: - # load the library in another process. - # if it core dumps, it will not crash the current process - nccl_integrity_check(so_file) - nccl = ctypes.CDLL(so_file) -except Exception as e: - logger.error( - "Failed to load NCCL library from %s ." - "It is expected if you are not running on NVIDIA/AMD GPUs." - "Otherwise, the nccl library might not exist, be corrupted " - "or it does not support the current platform %s." - "One solution is to download libnccl2 version 2.18 from " - "https://developer.download.nvidia.com/compute/cuda/repos/ " - "and extract the libnccl.so.2 file. 
If you already have the " - "library, please set the environment variable VLLM_NCCL_SO_PATH" - " to point to the correct nccl library path.", so_file, - platform.platform()) - raise e - -# === export types and functions from nccl to Python === -# for the original nccl definition, please check -# https://github.com/NVIDIA/nccl/blob/master/src/nccl.h.in - -ncclResult_t = ctypes.c_int - -_c_ncclGetErrorString = nccl.ncclGetErrorString -_c_ncclGetErrorString.restype = ctypes.c_char_p -_c_ncclGetErrorString.argtypes = [ncclResult_t] - - -def NCCL_CHECK(result: ncclResult_t) -> None: - if result != 0: - error_str = _c_ncclGetErrorString(result) - error_str = error_str.decode("utf-8") - raise RuntimeError(f"NCCL error: {error_str}") - - -# equivalent to c declaration: -# ncclResult_t ncclGetVersion(int *version); -_c_ncclGetVersion = nccl.ncclGetVersion -_c_ncclGetVersion.restype = ctypes.c_int -_c_ncclGetVersion.argtypes = [ctypes.POINTER(ctypes.c_int)] - - -def ncclGetVersion() -> str: - version = ctypes.c_int() - NCCL_CHECK(_c_ncclGetVersion(ctypes.byref(version))) - # something like 21903 --> "2.19.3" - version_str = str(version.value) - major = version_str[0].lstrip("0") - minor = version_str[1:3].lstrip("0") - patch = version_str[3:].lstrip("0") - return f"{major}.{minor}.{patch}" - - -class NcclUniqueId(ctypes.Structure): - _fields_ = [("internal", ctypes.c_byte * 128)] - - -# equivalent to c declaration: -# ncclResult_t ncclGetUniqueId(ncclUniqueId* uniqueId); -_c_ncclGetUniqueId = nccl.ncclGetUniqueId -_c_ncclGetUniqueId.restype = ctypes.c_int -_c_ncclGetUniqueId.argtypes = [ctypes.POINTER(NcclUniqueId)] - - -def ncclGetUniqueId() -> NcclUniqueId: - unique_id = NcclUniqueId() - NCCL_CHECK(_c_ncclGetUniqueId(ctypes.byref(unique_id))) - return unique_id - - -# equivalent to c declaration: -# ncclResult_t ncclCommInitRank( -# ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank); -# note that ncclComm_t is a pointer type, so the first argument -# is a pointer to a pointer -_c_ncclCommInitRank = nccl.ncclCommInitRank -_c_ncclCommInitRank.restype = ctypes.c_int -_c_ncclCommInitRank.argtypes = [ - ctypes.POINTER(ctypes.c_void_p), ctypes.c_int, NcclUniqueId, ctypes.c_int -] - -ncclDataType_t = ctypes.c_int - - -class ncclDataTypeEnum: - ncclInt8 = 0 - ncclChar = 0 - ncclUint8 = 1 - ncclInt32 = 2 - ncclInt = 2 - ncclUint32 = 3 - ncclInt64 = 4 - ncclUint64 = 5 - ncclFloat16 = 6 - ncclHalf = 6 - ncclFloat32 = 7 - ncclFloat = 7 - ncclFloat64 = 8 - ncclDouble = 8 - ncclBfloat16 = 9 - ncclNumTypes = 10 - - @classmethod - def from_torch(cls, dtype: torch.dtype) -> int: - if dtype == torch.int8: - return cls.ncclInt8 - if dtype == torch.uint8: - return cls.ncclUint8 - if dtype == torch.int32: - return cls.ncclInt32 - if dtype == torch.int64: - return cls.ncclInt64 - if dtype == torch.float16: - return cls.ncclFloat16 - if dtype == torch.float32: - return cls.ncclFloat32 - if dtype == torch.float64: - return cls.ncclFloat64 - if dtype == torch.bfloat16: - return cls.ncclBfloat16 - raise ValueError(f"Unsupported dtype: {dtype}") - - -ncclRedOp_t = ctypes.c_int - - -class ncclRedOpTypeEnum: - ncclSum = 0 - ncclProd = 1 - ncclMax = 2 - ncclMin = 3 - ncclAvg = 4 - ncclNumOps = 5 - - @classmethod - def from_torch(cls, op: ReduceOp) -> int: - if op == ReduceOp.SUM: - return cls.ncclSum - if op == ReduceOp.PRODUCT: - return cls.ncclProd - if op == ReduceOp.MAX: - return cls.ncclMax - if op == ReduceOp.MIN: - return cls.ncclMin - if op == ReduceOp.AVG: - return cls.ncclAvg - raise 
ValueError(f"Unsupported op: {op}") - - -# equivalent to c declaration: -# ncclResult_t ncclAllReduce( -# const void* sendbuff, void* recvbuff, size_t count, -# ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm, -# udaStream_t stream); -# note that cudaStream_t is a pointer type, so the last argument is a pointer -_c_ncclAllReduce = nccl.ncclAllReduce -_c_ncclAllReduce.restype = ctypes.c_int -_c_ncclAllReduce.argtypes = [ - ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, ncclRedOp_t, - ncclDataType_t, ctypes.c_void_p, ctypes.c_void_p -] - -# be cautious! this is a collective call, it will block until all -# processes in the communicator have called this function. -# because Python object destruction can happen in random order, -# it is better not to call it at all. -# equivalent to c declaration: -# ncclResult_t ncclCommDestroy(ncclComm_t comm); -_c_ncclCommDestroy = nccl.ncclCommDestroy -_c_ncclCommDestroy.restype = ctypes.c_int -_c_ncclCommDestroy.argtypes = [ctypes.c_void_p] - class NCCLCommunicator: @@ -217,6 +20,7 @@ def __init__( self, group: Optional[ProcessGroup] = None, device: Optional[Union[int, str, torch.device]] = None, + library_path: Optional[str] = None, ): """ Args: @@ -224,9 +28,17 @@ def __init__( default process group. device: the device to bind the NCCLCommunicator to. If None, it will be bind to f"cuda:{local_rank}". + library_path: the path to the NCCL library. If None, it will + use the default library path. It is the caller's responsibility to make sure each communicator is bind to a unique device. """ + try: + self.nccl = NCCLLibrary(library_path) + except Exception: + self.disabled = True + return + self.disabled = False assert dist.is_initialized() group = get_cpu_world_group() if group is None else group assert dist.get_backend(group) != dist.Backend.NCCL, ( @@ -236,9 +48,11 @@ def __init__( self.rank = dist.get_rank(group) self.world_size = dist.get_world_size(group) if self.rank == 0: - self.unique_id = ncclGetUniqueId() + # get the unique id from NCCL + self.unique_id = self.nccl.ncclGetUniqueId() else: - self.unique_id = NcclUniqueId() + # construct an empty unique id + self.unique_id = ncclUniqueId() tensor = torch.ByteTensor(list(self.unique_id.internal)) ranks = dist.get_process_group_ranks(group) # arg `src` in `broadcast` is the global rank @@ -246,7 +60,6 @@ def __init__( byte_list = tensor.tolist() for i, byte in enumerate(byte_list): self.unique_id.internal[i] = byte - self.comm = ctypes.c_void_p() if device is None: local_rank = get_local_rank() device = torch.device(f"cuda:{local_rank}") @@ -261,15 +74,16 @@ def __init__( # `torch.cuda.device` is a context manager that changes the # current cuda device to the specified one with torch.cuda.device(device): - NCCL_CHECK( - _c_ncclCommInitRank(ctypes.byref(self.comm), self.world_size, - self.unique_id, self.rank)) + self.comm: ncclComm_t = self.nccl.ncclCommInitRank( + self.world_size, self.unique_id, self.rank) self.stream = torch.cuda.Stream() def all_reduce(self, tensor: torch.Tensor, op: ReduceOp = ReduceOp.SUM, stream=None): + if self.disabled: + return # nccl communicator created on a specific device # will only work on tensors on the same device # otherwise it will cause "illegal memory access" @@ -278,10 +92,8 @@ def all_reduce(self, f"but the input tensor is on {tensor.device}") if stream is None: stream = self.stream - NCCL_CHECK( - _c_ncclAllReduce(ctypes.c_void_p(tensor.data_ptr()), - ctypes.c_void_p(tensor.data_ptr()), - tensor.numel(), - 
ncclDataTypeEnum.from_torch(tensor.dtype), - ncclRedOpTypeEnum.from_torch(op), self.comm, - ctypes.c_void_p(stream.cuda_stream))) + self.nccl.ncclAllReduce(buffer_type(tensor.data_ptr()), + buffer_type(tensor.data_ptr()), tensor.numel(), + ncclDataTypeEnum.from_torch(tensor.dtype), + ncclRedOpTypeEnum.from_torch(op), self.comm, + cudaStream_t(stream.cuda_stream)) diff --git a/vllm/distributed/device_communicators/pynccl_utils.py b/vllm/distributed/device_communicators/pynccl_utils.py index 44e4f39217a41..dc4fb186d48d2 100644 --- a/vllm/distributed/device_communicators/pynccl_utils.py +++ b/vllm/distributed/device_communicators/pynccl_utils.py @@ -9,8 +9,7 @@ logger = init_logger(__name__) try: - from vllm.distributed.device_communicators.pynccl import (NCCLCommunicator, - ncclGetVersion) + from vllm.distributed.device_communicators.pynccl import NCCLCommunicator except Exception as e: # in non-NVIDIA environments, we can't import the nccl module # e.g. when running on machines with AMD GPUs @@ -40,8 +39,8 @@ def set_pynccl_stream(stream: torch.cuda.Stream): def init_process_group(group: Optional[ProcessGroup] = None) -> None: assert not is_initialized() global comm - logger.info("vLLM is using nccl==%s", ncclGetVersion()) comm = NCCLCommunicator(group=group) + logger.info("vLLM is using nccl==%s", comm.nccl.ncclGetVersion()) def all_reduce(input_: torch.Tensor, op=ReduceOp.SUM) -> None: diff --git a/vllm/distributed/device_communicators/pynccl_wrapper.py b/vllm/distributed/device_communicators/pynccl_wrapper.py index c631e223877c8..43d85674b23d0 100644 --- a/vllm/distributed/device_communicators/pynccl_wrapper.py +++ b/vllm/distributed/device_communicators/pynccl_wrapper.py @@ -150,6 +150,11 @@ class NCCLLibrary: buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t, ncclRedOp_t, ncclComm_t, cudaStream_t ]), + + # be cautious! this is a collective call, it will block until all + # processes in the communicator have called this function. + # because Python object destruction can happen in random order, + # it is better not to call it at all. 
# ncclResult_t ncclCommDestroy(ncclComm_t comm); Function("ncclCommDestroy", ncclResult_t, [ncclComm_t]), ] From cadcd02619a3efb7d56b56019ca988e3a62c2f5e Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 3 May 2024 15:19:16 -0700 Subject: [PATCH 03/43] fix import --- tests/distributed/test_pynccl.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/distributed/test_pynccl.py b/tests/distributed/test_pynccl.py index b6f461b76ed03..0137edffb4aa0 100644 --- a/tests/distributed/test_pynccl.py +++ b/tests/distributed/test_pynccl.py @@ -5,8 +5,8 @@ import vllm.distributed.device_communicators.pynccl_utils as pynccl_utils from vllm.distributed.communication_op import tensor_model_parallel_all_reduce -from vllm.distributed.device_communicators.pynccl import (NCCLCommunicator, - ncclGetUniqueId) +from vllm.distributed.device_communicators.pynccl import NCCLCommunicator +from vllm.distributed.device_communicators.pynccl_wrapper import NCCLLibrary from vllm.distributed.parallel_state import ( ensure_model_parallel_initialized, get_tensor_model_parallel_cpu_group, init_distributed_environment, with_pynccl_for_all_reduce) @@ -147,7 +147,8 @@ def test_pynccl_with_cudagraph(): def test_ncclGetUniqueId(): - unique_id = ncclGetUniqueId() + lib = NCCLLibrary() + unique_id = lib.ncclGetUniqueId() # `list(unique_id.internal)` is something like this: # [34, -16, 23, 83, 109, -19, 59, 95, 2, 0, -86, 55, 10, -128, 0, 29, 0, # 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, From 79187983ea16eb5e129372eea3601bafe2c8d1fa Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 3 May 2024 15:33:06 -0700 Subject: [PATCH 04/43] remove pynccl_utils.init_process_group --- tests/distributed/test_pynccl.py | 9 ++---- .../device_communicators/pynccl_utils.py | 9 +----- vllm/distributed/parallel_state.py | 30 +++++++++++++------ vllm/worker/worker.py | 15 ---------- 4 files changed, 25 insertions(+), 38 deletions(-) diff --git a/tests/distributed/test_pynccl.py b/tests/distributed/test_pynccl.py index 0137edffb4aa0..083ca6012efbd 100644 --- a/tests/distributed/test_pynccl.py +++ b/tests/distributed/test_pynccl.py @@ -3,13 +3,12 @@ import pytest import torch -import vllm.distributed.device_communicators.pynccl_utils as pynccl_utils from vllm.distributed.communication_op import tensor_model_parallel_all_reduce from vllm.distributed.device_communicators.pynccl import NCCLCommunicator from vllm.distributed.device_communicators.pynccl_wrapper import NCCLLibrary -from vllm.distributed.parallel_state import ( - ensure_model_parallel_initialized, get_tensor_model_parallel_cpu_group, - init_distributed_environment, with_pynccl_for_all_reduce) +from vllm.distributed.parallel_state import (ensure_model_parallel_initialized, + init_distributed_environment, + with_pynccl_for_all_reduce) from vllm.utils import update_environment_variables @@ -97,8 +96,6 @@ def multiple_tp_with_vllm_worker_fn(): device = torch.device(f"cuda:{torch.distributed.get_rank()}") torch.cuda.set_device(torch.distributed.get_rank()) ensure_model_parallel_initialized(2, 2) - pynccl_utils.init_process_group( - group=get_tensor_model_parallel_cpu_group()) tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device) with with_pynccl_for_all_reduce(): # two tp groups can communicate independently diff --git a/vllm/distributed/device_communicators/pynccl_utils.py b/vllm/distributed/device_communicators/pynccl_utils.py index dc4fb186d48d2..cffecd5422eed 100644 --- 
a/vllm/distributed/device_communicators/pynccl_utils.py +++ b/vllm/distributed/device_communicators/pynccl_utils.py @@ -2,7 +2,7 @@ from typing import Optional import torch -from torch.distributed import ProcessGroup, ReduceOp +from torch.distributed import ReduceOp from vllm.logger import init_logger @@ -36,13 +36,6 @@ def set_pynccl_stream(stream: torch.cuda.Stream): pass -def init_process_group(group: Optional[ProcessGroup] = None) -> None: - assert not is_initialized() - global comm - comm = NCCLCommunicator(group=group) - logger.info("vLLM is using nccl==%s", comm.nccl.ncclGetVersion()) - - def all_reduce(input_: torch.Tensor, op=ReduceOp.SUM) -> None: """All-reduces the input tensor across the process group.""" assert input_.is_cuda, f"{input_} should be a cuda tensor" diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index be5bb4e857caf..774cb1b44da47 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -4,9 +4,10 @@ # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. """Tensor and pipeline parallel groups.""" import contextlib -from typing import Optional +from typing import List, Optional import torch +from torch.distributed import ProcessGroup import vllm.envs as envs from vllm.logger import init_logger @@ -14,10 +15,11 @@ logger = init_logger(__name__) # Tensor model parallel group that the current rank belongs to. -_TP_DEVICE_GROUP = None -_TP_CPU_GROUP = None +_TP_DEVICE_GROUP: Optional[ProcessGroup] = None +_TP_CPU_GROUP: Optional[ProcessGroup] = None +_TP_PYNCCL_COMMUNICATOR = None # Pipeline model parallel group that the current rank belongs to. -_PIPELINE_MODEL_PARALLEL_GROUP = None +_PIPELINE_MODEL_PARALLEL_GROUP: Optional[ProcessGroup] = None # when people blindly call `torch.distributed.all_reduce` etc, # it will use this group. It is initialized with the `backend` @@ -41,7 +43,7 @@ # A list of global ranks for each pipeline group to ease calculation of the # source rank when broadcasting from the first or last pipeline stage. -_PIPELINE_GLOBAL_RANKS = None +_PIPELINE_GLOBAL_RANKS: Optional[List[int]] = None _LOCAL_RANK = -1 @@ -133,25 +135,35 @@ def initialize_model_parallel( rank = torch.distributed.get_rank() # Build the tensor model-parallel groups. - global _TP_DEVICE_GROUP, _TP_CPU_GROUP + global _TP_DEVICE_GROUP, _TP_CPU_GROUP, _TP_PYNCCL_COMMUNICATOR assert _TP_DEVICE_GROUP is None, ( "tensor model parallel group is already initialized") for i in range(num_tensor_model_parallel_groups): - ranks = range(i * tensor_model_parallel_size, - (i + 1) * tensor_model_parallel_size) + ranks = list( + range(i * tensor_model_parallel_size, + (i + 1) * tensor_model_parallel_size)) group = torch.distributed.new_group(ranks, backend=backend) cpu_group = torch.distributed.new_group(ranks, backend="gloo") if rank in ranks: _TP_DEVICE_GROUP = group _TP_CPU_GROUP = cpu_group + from vllm.distributed.device_communicators.pynccl import NCCLCommunicator + _TP_PYNCCL_COMMUNICATOR = NCCLCommunicator(group=_TP_CPU_GROUP) + + logger.info("vLLM is using nccl==%s", + _TP_PYNCCL_COMMUNICATOR.nccl.ncclGetVersion()) + + from vllm.distributed.device_communicators import pynccl_utils + pynccl_utils.comm = _TP_PYNCCL_COMMUNICATOR + # Build the pipeline model-parallel groups. 
global _PIPELINE_MODEL_PARALLEL_GROUP global _PIPELINE_GLOBAL_RANKS assert _PIPELINE_MODEL_PARALLEL_GROUP is None, ( "pipeline model parallel group is already initialized") for i in range(num_pipeline_model_parallel_groups): - ranks = range(i, world_size, num_pipeline_model_parallel_groups) + ranks = list(range(i, world_size, num_pipeline_model_parallel_groups)) group = torch.distributed.new_group(ranks, backend=backend) if rank in ranks: _PIPELINE_MODEL_PARALLEL_GROUP = group diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 808261e47318b..17778a965b03a 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -11,7 +11,6 @@ VisionLanguageConfig) from vllm.distributed import (broadcast_tensor_dict, ensure_model_parallel_initialized, - get_tensor_model_parallel_cpu_group, init_distributed_environment) from vllm.distributed.device_communicators import pynccl_utils from vllm.distributed.device_communicators.custom_all_reduce import ( @@ -292,20 +291,6 @@ def init_worker_distributed_environment( ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, parallel_config.pipeline_parallel_size) - if pynccl_utils.is_initialized(): - pynccl_world_size = pynccl_utils.get_world_size() - if pynccl_world_size != parallel_config.world_size: - raise RuntimeError( - "pynccl is already initialized but the pynccl world " - "size does not match parallel_config.world_size " - f"({pynccl_world_size} vs. {parallel_config.world_size}).") - elif parallel_config.world_size > 1: - # NOTE(woosuk): We don't initialize pynccl process group when world size - # is 1. - # NOTE(kaichao): By default, pynccl is initialized for tp group. - pynccl_utils.init_process_group( - group=get_tensor_model_parallel_cpu_group()) - # Initialize a custom fast all-reduce implementation. if not parallel_config.disable_custom_all_reduce: init_custom_ar() From 592403849b3892836d5918eb596a740fc468acd8 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 3 May 2024 15:34:40 -0700 Subject: [PATCH 05/43] remove pynccl_utils.is_initialized --- vllm/distributed/device_communicators/pynccl_utils.py | 5 ----- vllm/worker/model_runner.py | 3 +-- vllm/worker/worker.py | 3 +-- 3 files changed, 2 insertions(+), 9 deletions(-) diff --git a/vllm/distributed/device_communicators/pynccl_utils.py b/vllm/distributed/device_communicators/pynccl_utils.py index cffecd5422eed..3de53737b8239 100644 --- a/vllm/distributed/device_communicators/pynccl_utils.py +++ b/vllm/distributed/device_communicators/pynccl_utils.py @@ -20,11 +20,6 @@ comm: Optional["NCCLCommunicator"] = None -def is_initialized() -> bool: - """Returns whether the NCCL backend is initialized.""" - return comm is not None - - @contextlib.contextmanager def set_pynccl_stream(stream: torch.cuda.Stream): """Set the cuda stream for communication""" diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index bbb1f5205af5e..d0e271b4dd9ac 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1056,8 +1056,7 @@ def __call__(self, *args, **kwargs): @contextlib.contextmanager def _maybe_pynccl(): - if pynccl_utils.is_initialized( - ) and not custom_all_reduce.is_initialized(): + if not custom_all_reduce.is_initialized(): with with_pynccl_for_all_reduce(): yield else: diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 17778a965b03a..f6dcb2970be5f 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -297,8 +297,7 @@ def init_worker_distributed_environment( # A small all_reduce for warmup. 
torch.distributed.all_reduce(torch.zeros(1).cuda()) - if pynccl_utils.is_initialized(): - pynccl_utils.all_reduce(torch.zeros(1).cuda()) + pynccl_utils.all_reduce(torch.zeros(1).cuda()) def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype): From 813b0476b6b5c2fd08cc7b85e3b4dd256369c214 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 3 May 2024 15:36:32 -0700 Subject: [PATCH 06/43] remove pynccl_utils.destroy_process_group --- vllm/distributed/device_communicators/pynccl_utils.py | 5 ----- vllm/distributed/parallel_state.py | 8 ++++---- 2 files changed, 4 insertions(+), 9 deletions(-) diff --git a/vllm/distributed/device_communicators/pynccl_utils.py b/vllm/distributed/device_communicators/pynccl_utils.py index 3de53737b8239..f2950e5eee63f 100644 --- a/vllm/distributed/device_communicators/pynccl_utils.py +++ b/vllm/distributed/device_communicators/pynccl_utils.py @@ -38,11 +38,6 @@ def all_reduce(input_: torch.Tensor, op=ReduceOp.SUM) -> None: comm.all_reduce(input_, op) -def destroy_process_group() -> None: - global comm - comm = None - - def get_world_size() -> int: """Returns the world size.""" assert comm is not None diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 774cb1b44da47..0c55af2c55629 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -307,16 +307,16 @@ def destroy_model_parallel(): if _TP_CPU_GROUP: torch.distributed.destroy_process_group(_TP_CPU_GROUP) _TP_CPU_GROUP = None + from vllm.distributed.device_communicators import pynccl_utils + del pynccl_utils.comm + pynccl_utils.comm = None + global _PIPELINE_MODEL_PARALLEL_GROUP if _PIPELINE_MODEL_PARALLEL_GROUP: torch.distributed.destroy_process_group(_PIPELINE_MODEL_PARALLEL_GROUP) _PIPELINE_MODEL_PARALLEL_GROUP = None global _PIPELINE_GLOBAL_RANKS _PIPELINE_GLOBAL_RANKS = None - from vllm.distributed.device_communicators import pynccl_utils - - # Destroy the pynccl states if any. - pynccl_utils.destroy_process_group() # Whether to use pynccl for nccl all reduce. 
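Taken together, PATCH 01-06 replace the old module-level ctypes bindings with the cached NCCLLibrary wrapper and let parallel_state own the pynccl communicator directly. The snippet below is a minimal sketch, not part of the patch series, assuming a CUDA machine where find_nccl_library() can locate libnccl.so.2; it mirrors the updated test_ncclGetUniqueId from PATCH 03 and shows the class-level cache from PATCH 01 at work when the wrapper is constructed a second time.

# Sketch only: exercises the NCCLLibrary wrapper introduced in PATCH 01.
from vllm.distributed.device_communicators.pynccl_wrapper import NCCLLibrary

lib = NCCLLibrary()                 # so_file resolved via find_nccl_library()
print(lib.ncclGetVersion())         # e.g. the integer 21903 reported as "2.19.3"
unique_id = lib.ncclGetUniqueId()   # 128-byte opaque id to broadcast to other ranks
assert len(list(unique_id.internal)) == 128

lib_again = NCCLLibrary()           # same so_file: served from the class-level
                                    # path_to_library_cache / path_to_dict_mapping
assert lib_again.lib is lib.lib     # the shared library is not loaded a second time
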
From b244e6c21f0379a11d5a2e51d36a2cb5ce9f158d Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 3 May 2024 15:37:06 -0700 Subject: [PATCH 07/43] remove pynccl_utils.get_world_size --- vllm/distributed/device_communicators/pynccl_utils.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/vllm/distributed/device_communicators/pynccl_utils.py b/vllm/distributed/device_communicators/pynccl_utils.py index f2950e5eee63f..2547908f8ff59 100644 --- a/vllm/distributed/device_communicators/pynccl_utils.py +++ b/vllm/distributed/device_communicators/pynccl_utils.py @@ -38,11 +38,5 @@ def all_reduce(input_: torch.Tensor, op=ReduceOp.SUM) -> None: comm.all_reduce(input_, op) -def get_world_size() -> int: - """Returns the world size.""" - assert comm is not None - return comm.world_size - - def get_nccl_backend() -> Optional["NCCLCommunicator"]: return comm From 7e15c98c1ee96fce369c65d5c794e9350a5250ca Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 3 May 2024 15:38:27 -0700 Subject: [PATCH 08/43] remove pynccl_utils.get_nccl_backend --- vllm/distributed/device_communicators/pynccl_utils.py | 4 ---- vllm/worker/model_runner.py | 7 +------ 2 files changed, 1 insertion(+), 10 deletions(-) diff --git a/vllm/distributed/device_communicators/pynccl_utils.py b/vllm/distributed/device_communicators/pynccl_utils.py index 2547908f8ff59..8182f72da9638 100644 --- a/vllm/distributed/device_communicators/pynccl_utils.py +++ b/vllm/distributed/device_communicators/pynccl_utils.py @@ -36,7 +36,3 @@ def all_reduce(input_: torch.Tensor, op=ReduceOp.SUM) -> None: assert input_.is_cuda, f"{input_} should be a cuda tensor" assert comm is not None comm.all_reduce(input_, op) - - -def get_nccl_backend() -> Optional["NCCLCommunicator"]: - return comm diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index d0e271b4dd9ac..f647e994d1fc8 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -12,8 +12,7 @@ from vllm.config import (DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, VisionLanguageConfig) from vllm.distributed import broadcast_tensor_dict, with_pynccl_for_all_reduce -from vllm.distributed.device_communicators import (custom_all_reduce, - pynccl_utils) +from vllm.distributed.device_communicators import custom_all_reduce from vllm.logger import init_logger from vllm.lora.layers import LoRAMapping from vllm.lora.request import LoRARequest @@ -858,10 +857,6 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: Since it is used for decoding-only, it assumes there's only 1 token per sequence in the batch. """ - # NOTE(woosuk): This is a hack to ensure that the NCCL backend is never - # deleted before the CUDA graphs. - self.pynccl_backend = pynccl_utils.get_nccl_backend() - assert not self.model_config.enforce_eager logger.info("Capturing the model for CUDA graphs. This may lead to " "unexpected consequences if the model is not static. 
To " From e610f643cbf78a5d06dc7c2eb72e6e3b0037cdeb Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 3 May 2024 15:43:35 -0700 Subject: [PATCH 09/43] remove is_pynccl_enabled_for_all_reduce --- vllm/distributed/communication_op.py | 8 +++++--- vllm/distributed/parallel_state.py | 6 ------ 2 files changed, 5 insertions(+), 9 deletions(-) diff --git a/vllm/distributed/communication_op.py b/vllm/distributed/communication_op.py index b539a7beedbfe..f5cf0243a9344 100644 --- a/vllm/distributed/communication_op.py +++ b/vllm/distributed/communication_op.py @@ -7,8 +7,7 @@ from .parallel_state import (get_cpu_world_group, get_tensor_model_parallel_group, get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, - is_pynccl_enabled_for_all_reduce) + get_tensor_model_parallel_world_size) def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor: @@ -26,6 +25,8 @@ def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor: from vllm.distributed.device_communicators import pynccl_utils from vllm.distributed.device_communicators.custom_all_reduce import ( custom_all_reduce) + from vllm.distributed.device_communicators.parallel_state import ( + _TP_PYNCCL_COMMUNICATOR) # Bypass the function if we are using only 1 GPU. if get_tensor_model_parallel_world_size() == 1: @@ -33,7 +34,8 @@ def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor: out = custom_all_reduce(input_) if out is not None: return out - if is_pynccl_enabled_for_all_reduce(): + if _TP_PYNCCL_COMMUNICATOR is not None and \ + not _TP_PYNCCL_COMMUNICATOR.disabled: pynccl_utils.all_reduce(input_) else: torch.distributed.all_reduce(input_, diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 0c55af2c55629..8bdcd9f7bd618 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -343,9 +343,3 @@ def with_pynccl_for_all_reduce(): with pynccl_utils.set_pynccl_stream(stream): yield _ENABLE_PYNCCL_FOR_ALL_REDUCE = old - - -def is_pynccl_enabled_for_all_reduce(): - """check if pynccl is enabled for all reduce""" - global _ENABLE_PYNCCL_FOR_ALL_REDUCE - return _ENABLE_PYNCCL_FOR_ALL_REDUCE From 8480995817643c0f8b2e40a21c91ab9b7ec75b4b Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 3 May 2024 15:46:59 -0700 Subject: [PATCH 10/43] remove _ENABLE_PYNCCL_FOR_ALL_REDUCE --- vllm/distributed/parallel_state.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 8bdcd9f7bd618..4fa82d4d0d01e 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -319,14 +319,10 @@ def destroy_model_parallel(): _PIPELINE_GLOBAL_RANKS = None -# Whether to use pynccl for nccl all reduce. -# We use pynccl for all reduce when using CUDA graph, because torch.distributed -# is not well supported by CUDA graph. -_ENABLE_PYNCCL_FOR_ALL_REDUCE = False - - @contextlib.contextmanager def with_pynccl_for_all_reduce(): + # We use pynccl for all reduce when using CUDA graph, because + # torch.distributed is not well supported by CUDA graph. from vllm.distributed.device_communicators import pynccl_utils """use pynccl instead of torch.distributed for all reduce""" tp_size = get_tensor_model_parallel_world_size() @@ -335,11 +331,12 @@ def with_pynccl_for_all_reduce(): # NOTE(woosuk): We don't initialize pynccl when tp_size is 1. 
yield else: - global _ENABLE_PYNCCL_FOR_ALL_REDUCE - old = _ENABLE_PYNCCL_FOR_ALL_REDUCE - _ENABLE_PYNCCL_FOR_ALL_REDUCE = True + global _TP_PYNCCL_COMMUNICATOR + assert _TP_PYNCCL_COMMUNICATOR is not None + old = _TP_PYNCCL_COMMUNICATOR.disabled + _TP_PYNCCL_COMMUNICATOR.disabled = True stream = torch.cuda.current_stream() with pynccl_utils.set_pynccl_stream(stream): yield - _ENABLE_PYNCCL_FOR_ALL_REDUCE = old + _TP_PYNCCL_COMMUNICATOR.disabled = old From 5ed6f070cd5ce9dad499cff233f4ce53edaf8e3e Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 3 May 2024 15:49:08 -0700 Subject: [PATCH 11/43] remove set_pynccl_stream --- .../distributed/device_communicators/pynccl_utils.py | 12 ------------ vllm/distributed/parallel_state.py | 9 ++++++--- 2 files changed, 6 insertions(+), 15 deletions(-) diff --git a/vllm/distributed/device_communicators/pynccl_utils.py b/vllm/distributed/device_communicators/pynccl_utils.py index 8182f72da9638..6f75682a7580d 100644 --- a/vllm/distributed/device_communicators/pynccl_utils.py +++ b/vllm/distributed/device_communicators/pynccl_utils.py @@ -1,4 +1,3 @@ -import contextlib from typing import Optional import torch @@ -20,17 +19,6 @@ comm: Optional["NCCLCommunicator"] = None -@contextlib.contextmanager -def set_pynccl_stream(stream: torch.cuda.Stream): - """Set the cuda stream for communication""" - try: - assert comm is not None - comm.stream = stream - yield - finally: - pass - - def all_reduce(input_: torch.Tensor, op=ReduceOp.SUM) -> None: """All-reduces the input tensor across the process group.""" assert input_.is_cuda, f"{input_} should be a cuda tensor" diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 4fa82d4d0d01e..9c7cd1297b483 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -323,7 +323,6 @@ def destroy_model_parallel(): def with_pynccl_for_all_reduce(): # We use pynccl for all reduce when using CUDA graph, because # torch.distributed is not well supported by CUDA graph. - from vllm.distributed.device_communicators import pynccl_utils """use pynccl instead of torch.distributed for all reduce""" tp_size = get_tensor_model_parallel_world_size() if tp_size == 1: @@ -337,6 +336,10 @@ def with_pynccl_for_all_reduce(): _TP_PYNCCL_COMMUNICATOR.disabled = True stream = torch.cuda.current_stream() - with pynccl_utils.set_pynccl_stream(stream): - yield + old_stream = _TP_PYNCCL_COMMUNICATOR.stream + + _TP_PYNCCL_COMMUNICATOR.stream = stream + yield + + _TP_PYNCCL_COMMUNICATOR.stream = old_stream _TP_PYNCCL_COMMUNICATOR.disabled = old From 8134287979be9a4222b8fdac1d3f0ba89cf9c573 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 3 May 2024 15:54:57 -0700 Subject: [PATCH 12/43] remove pynccl utils --- vllm/distributed/communication_op.py | 3 +-- .../device_communicators/pynccl_utils.py | 26 ------------------- vllm/distributed/parallel_state.py | 11 ++++---- vllm/worker/worker.py | 5 ---- 4 files changed, 7 insertions(+), 38 deletions(-) delete mode 100644 vllm/distributed/device_communicators/pynccl_utils.py diff --git a/vllm/distributed/communication_op.py b/vllm/distributed/communication_op.py index f5cf0243a9344..516bde543e494 100644 --- a/vllm/distributed/communication_op.py +++ b/vllm/distributed/communication_op.py @@ -22,7 +22,6 @@ def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor: TLDR: always assume this function modifies its input, but use the return value as the output. 
""" - from vllm.distributed.device_communicators import pynccl_utils from vllm.distributed.device_communicators.custom_all_reduce import ( custom_all_reduce) from vllm.distributed.device_communicators.parallel_state import ( @@ -36,7 +35,7 @@ def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor: return out if _TP_PYNCCL_COMMUNICATOR is not None and \ not _TP_PYNCCL_COMMUNICATOR.disabled: - pynccl_utils.all_reduce(input_) + _TP_PYNCCL_COMMUNICATOR.all_reduce(input_) else: torch.distributed.all_reduce(input_, group=get_tensor_model_parallel_group()) diff --git a/vllm/distributed/device_communicators/pynccl_utils.py b/vllm/distributed/device_communicators/pynccl_utils.py deleted file mode 100644 index 6f75682a7580d..0000000000000 --- a/vllm/distributed/device_communicators/pynccl_utils.py +++ /dev/null @@ -1,26 +0,0 @@ -from typing import Optional - -import torch -from torch.distributed import ReduceOp - -from vllm.logger import init_logger - -logger = init_logger(__name__) - -try: - from vllm.distributed.device_communicators.pynccl import NCCLCommunicator -except Exception as e: - # in non-NVIDIA environments, we can't import the nccl module - # e.g. when running on machines with AMD GPUs - logger.info("Failed to import NCCL library: %s", e) - logger.info("It is expected if you are not running on NVIDIA GPUs.") - pass - -comm: Optional["NCCLCommunicator"] = None - - -def all_reduce(input_: torch.Tensor, op=ReduceOp.SUM) -> None: - """All-reduces the input tensor across the process group.""" - assert input_.is_cuda, f"{input_} should be a cuda tensor" - assert comm is not None - comm.all_reduce(input_, op) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 9c7cd1297b483..c7b01edf161ee 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -86,6 +86,8 @@ def init_distributed_environment( local_rank = envs.LOCAL_RANK global _LOCAL_RANK _LOCAL_RANK = local_rank + # A small all_reduce for warmup. + torch.distributed.all_reduce(torch.zeros(1).cuda()) def initialize_model_parallel( @@ -154,8 +156,8 @@ def initialize_model_parallel( logger.info("vLLM is using nccl==%s", _TP_PYNCCL_COMMUNICATOR.nccl.ncclGetVersion()) - from vllm.distributed.device_communicators import pynccl_utils - pynccl_utils.comm = _TP_PYNCCL_COMMUNICATOR + # A small all_reduce for warmup. + _TP_PYNCCL_COMMUNICATOR.all_reduce(torch.zeros(1).cuda()) # Build the pipeline model-parallel groups. 
global _PIPELINE_MODEL_PARALLEL_GROUP @@ -307,9 +309,8 @@ def destroy_model_parallel(): if _TP_CPU_GROUP: torch.distributed.destroy_process_group(_TP_CPU_GROUP) _TP_CPU_GROUP = None - from vllm.distributed.device_communicators import pynccl_utils - del pynccl_utils.comm - pynccl_utils.comm = None + global _TP_PYNCCL_COMMUNICATOR + _TP_PYNCCL_COMMUNICATOR = None global _PIPELINE_MODEL_PARALLEL_GROUP if _PIPELINE_MODEL_PARALLEL_GROUP: diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index f6dcb2970be5f..18ca54c7cc75d 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -12,7 +12,6 @@ from vllm.distributed import (broadcast_tensor_dict, ensure_model_parallel_initialized, init_distributed_environment) -from vllm.distributed.device_communicators import pynccl_utils from vllm.distributed.device_communicators.custom_all_reduce import ( init_custom_ar) from vllm.lora.request import LoRARequest @@ -295,10 +294,6 @@ def init_worker_distributed_environment( if not parallel_config.disable_custom_all_reduce: init_custom_ar() - # A small all_reduce for warmup. - torch.distributed.all_reduce(torch.zeros(1).cuda()) - pynccl_utils.all_reduce(torch.zeros(1).cuda()) - def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype): # Check if the GPU supports the dtype. From e65e9efd589302a206b782a86a5ea96d8908772d Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 3 May 2024 16:42:04 -0700 Subject: [PATCH 13/43] fix state --- vllm/distributed/parallel_state.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index c7b01edf161ee..e8555d7c242f3 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -334,7 +334,7 @@ def with_pynccl_for_all_reduce(): global _TP_PYNCCL_COMMUNICATOR assert _TP_PYNCCL_COMMUNICATOR is not None old = _TP_PYNCCL_COMMUNICATOR.disabled - _TP_PYNCCL_COMMUNICATOR.disabled = True + _TP_PYNCCL_COMMUNICATOR.disabled = False stream = torch.cuda.current_stream() old_stream = _TP_PYNCCL_COMMUNICATOR.stream From c8b6fc01c9f47c4c4cc74a8f1a9d2aecf6e1dd5a Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 3 May 2024 16:42:35 -0700 Subject: [PATCH 14/43] fix test --- tests/distributed/test_pynccl.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/distributed/test_pynccl.py b/tests/distributed/test_pynccl.py index 083ca6012efbd..baa1f9268fba0 100644 --- a/tests/distributed/test_pynccl.py +++ b/tests/distributed/test_pynccl.py @@ -40,6 +40,10 @@ def worker_fn_wrapper(fn): # and update the environment variables in the function def wrapped_fn(env): update_environment_variables(env) + import os + local_rank = os.environ['LOCAL_RANK'] + device = torch.device(f"cuda:{local_rank}") + torch.cuda.set_device(device) init_distributed_environment() fn() @@ -94,7 +98,6 @@ def test_pynccl_multiple_tp(): @worker_fn_wrapper def multiple_tp_with_vllm_worker_fn(): device = torch.device(f"cuda:{torch.distributed.get_rank()}") - torch.cuda.set_device(torch.distributed.get_rank()) ensure_model_parallel_initialized(2, 2) tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device) with with_pynccl_for_all_reduce(): From c7a2f0cbb7ba7ffbfd9db93d224520db0b84ccd0 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 3 May 2024 16:43:13 -0700 Subject: [PATCH 15/43] fix import --- vllm/distributed/communication_op.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/distributed/communication_op.py 
b/vllm/distributed/communication_op.py index 516bde543e494..d39af1d771e2a 100644 --- a/vllm/distributed/communication_op.py +++ b/vllm/distributed/communication_op.py @@ -24,8 +24,7 @@ def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor: """ from vllm.distributed.device_communicators.custom_all_reduce import ( custom_all_reduce) - from vllm.distributed.device_communicators.parallel_state import ( - _TP_PYNCCL_COMMUNICATOR) + from vllm.distributed.parallel_state import _TP_PYNCCL_COMMUNICATOR # Bypass the function if we are using only 1 GPU. if get_tensor_model_parallel_world_size() == 1: From 75a8d1119c9917b1ea6945b763326026fd70a3d1 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 3 May 2024 20:02:49 -0700 Subject: [PATCH 16/43] move warmup into pynccl --- vllm/distributed/device_communicators/pynccl.py | 6 ++++++ vllm/distributed/parallel_state.py | 6 ------ 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/vllm/distributed/device_communicators/pynccl.py b/vllm/distributed/device_communicators/pynccl.py index 52c86badf57bc..bebf299b324b9 100644 --- a/vllm/distributed/device_communicators/pynccl.py +++ b/vllm/distributed/device_communicators/pynccl.py @@ -39,6 +39,9 @@ def __init__( self.disabled = True return self.disabled = False + + logger.info("vLLM is using nccl==%s", self.nccl.ncclGetVersion()) + assert dist.is_initialized() group = get_cpu_world_group() if group is None else group assert dist.get_backend(group) != dist.Backend.NCCL, ( @@ -78,6 +81,9 @@ def __init__( self.world_size, self.unique_id, self.rank) self.stream = torch.cuda.Stream() + # A small all_reduce for warmup. + self.all_reduce(torch.zeros(1, device=device)) + def all_reduce(self, tensor: torch.Tensor, op: ReduceOp = ReduceOp.SUM, diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index e8555d7c242f3..015bb95664687 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -153,12 +153,6 @@ def initialize_model_parallel( from vllm.distributed.device_communicators.pynccl import NCCLCommunicator _TP_PYNCCL_COMMUNICATOR = NCCLCommunicator(group=_TP_CPU_GROUP) - logger.info("vLLM is using nccl==%s", - _TP_PYNCCL_COMMUNICATOR.nccl.ncclGetVersion()) - - # A small all_reduce for warmup. - _TP_PYNCCL_COMMUNICATOR.all_reduce(torch.zeros(1).cuda()) - # Build the pipeline model-parallel groups. global _PIPELINE_MODEL_PARALLEL_GROUP global _PIPELINE_GLOBAL_RANKS From 59c064eeb6c782e3153c13391cbd07fd4f9c1ff2 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 3 May 2024 20:04:06 -0700 Subject: [PATCH 17/43] add device --- vllm/distributed/parallel_state.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 015bb95664687..525f447e78999 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -151,7 +151,8 @@ def initialize_model_parallel( _TP_CPU_GROUP = cpu_group from vllm.distributed.device_communicators.pynccl import NCCLCommunicator - _TP_PYNCCL_COMMUNICATOR = NCCLCommunicator(group=_TP_CPU_GROUP) + _TP_PYNCCL_COMMUNICATOR = NCCLCommunicator(group=_TP_CPU_GROUP, + device=_LOCAL_RANK) # Build the pipeline model-parallel groups. 
global _PIPELINE_MODEL_PARALLEL_GROUP From 16aeef10c210cca169d08264872461892bf9e0b1 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 3 May 2024 20:08:33 -0700 Subject: [PATCH 18/43] fix device for allreduce warmup --- vllm/distributed/parallel_state.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 525f447e78999..e255175d74ecd 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -87,7 +87,10 @@ def init_distributed_environment( global _LOCAL_RANK _LOCAL_RANK = local_rank # A small all_reduce for warmup. - torch.distributed.all_reduce(torch.zeros(1).cuda()) + data = torch.zeros(1) + if torch.cuda.is_available(): + data = data.to(device=f"cuda:{local_rank}") + torch.distributed.all_reduce(data) def initialize_model_parallel( From 4710fc3ea6745089b806dfbf1587d8cdc70ecbd4 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 3 May 2024 20:20:51 -0700 Subject: [PATCH 19/43] improve ways of discovering default local rank --- vllm/distributed/parallel_state.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index e255175d74ecd..b7df2e0a22f9e 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -82,8 +82,13 @@ def init_distributed_environment( # set the local rank # local_rank is not available in torch ProcessGroup, # see https://github.com/pytorch/pytorch/issues/122816 - if local_rank == -1 and distributed_init_method == "env://": - local_rank = envs.LOCAL_RANK + if local_rank == -1: + # local rank not set, this usually happens in single-node + # setting, where we can use rank as local rank + if distributed_init_method == "env://": + local_rank = envs.LOCAL_RANK + else: + local_rank = rank global _LOCAL_RANK _LOCAL_RANK = local_rank # A small all_reduce for warmup. From c8542eca70a019c8818091f374c9332446797ce6 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 3 May 2024 20:22:41 -0700 Subject: [PATCH 20/43] make sure warmup happens in stream --- vllm/distributed/device_communicators/pynccl.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm/distributed/device_communicators/pynccl.py b/vllm/distributed/device_communicators/pynccl.py index bebf299b324b9..2f5e54daf4210 100644 --- a/vllm/distributed/device_communicators/pynccl.py +++ b/vllm/distributed/device_communicators/pynccl.py @@ -81,8 +81,9 @@ def __init__( self.world_size, self.unique_id, self.rank) self.stream = torch.cuda.Stream() - # A small all_reduce for warmup. - self.all_reduce(torch.zeros(1, device=device)) + # A small all_reduce for warmup. 
+ self.all_reduce(torch.zeros(1, device=device)) + self.stream.synchronize() def all_reduce(self, tensor: torch.Tensor, From b2d2661dd3a24fc922a058eff493964b7b2e41f5 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 3 May 2024 21:07:15 -0700 Subject: [PATCH 21/43] add disable --- vllm/distributed/device_communicators/pynccl.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/vllm/distributed/device_communicators/pynccl.py b/vllm/distributed/device_communicators/pynccl.py index 2f5e54daf4210..337ef0f94e88a 100644 --- a/vllm/distributed/device_communicators/pynccl.py +++ b/vllm/distributed/device_communicators/pynccl.py @@ -21,6 +21,7 @@ def __init__( group: Optional[ProcessGroup] = None, device: Optional[Union[int, str, torch.device]] = None, library_path: Optional[str] = None, + disabled: bool = False, ): """ Args: @@ -30,14 +31,22 @@ def __init__( it will be bind to f"cuda:{local_rank}". library_path: the path to the NCCL library. If None, it will use the default library path. + disabled: if True, the communicator will be disabled. This object + will not do anything, just serve as a placeholder. It is the caller's responsibility to make sure each communicator is bind to a unique device. """ + # explicit disable, e.g. world_size == 1 + self.disabled = disabled + if disabled: + return try: self.nccl = NCCLLibrary(library_path) except Exception: self.disabled = True return + # disable because of missing NCCL library + # e.g. in a non-GPU environment self.disabled = False logger.info("vLLM is using nccl==%s", self.nccl.ncclGetVersion()) From 67d1d9a68c465690a1e93583fe24d749aaaa4ab5 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 3 May 2024 21:09:28 -0700 Subject: [PATCH 22/43] do not init when world size is 1 --- vllm/distributed/parallel_state.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index b7df2e0a22f9e..ebc25e9c124db 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -160,7 +160,8 @@ def initialize_model_parallel( from vllm.distributed.device_communicators.pynccl import NCCLCommunicator _TP_PYNCCL_COMMUNICATOR = NCCLCommunicator(group=_TP_CPU_GROUP, - device=_LOCAL_RANK) + device=_LOCAL_RANK, + disabled=world_size == 1) # Build the pipeline model-parallel groups. global _PIPELINE_MODEL_PARALLEL_GROUP From c86199c9ce25f918f30c6736a26cd1452556cd2e Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 3 May 2024 21:46:35 -0700 Subject: [PATCH 23/43] fix initial state of pynccl allreduce --- vllm/distributed/parallel_state.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index ebc25e9c124db..41b7b9d7b8091 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -163,6 +163,10 @@ def initialize_model_parallel( device=_LOCAL_RANK, disabled=world_size == 1) + # by default it is disabled, e.g. in profiling models + # to use it, we have to use under `with_pynccl_for_all_reduce` + _TP_PYNCCL_COMMUNICATOR.disabled = True + # Build the pipeline model-parallel groups. 
global _PIPELINE_MODEL_PARALLEL_GROUP global _PIPELINE_GLOBAL_RANKS From 0030a317c88a6e37f69ab74a3ecb25e7372a1701 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 3 May 2024 23:03:41 -0700 Subject: [PATCH 24/43] add comments --- vllm/distributed/parallel_state.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 41b7b9d7b8091..d30ac48439f56 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -159,12 +159,14 @@ def initialize_model_parallel( _TP_CPU_GROUP = cpu_group from vllm.distributed.device_communicators.pynccl import NCCLCommunicator - _TP_PYNCCL_COMMUNICATOR = NCCLCommunicator(group=_TP_CPU_GROUP, - device=_LOCAL_RANK, - disabled=world_size == 1) - - # by default it is disabled, e.g. in profiling models - # to use it, we have to use under `with_pynccl_for_all_reduce` + _TP_PYNCCL_COMMUNICATOR = NCCLCommunicator( + group=_TP_CPU_GROUP, + device=_LOCAL_RANK, + disabled=tensor_model_parallel_size == 1) + + # by default it is disabled, e.g. in profiling models and prefill phase. + # to use it, we have to use under `with_pynccl_for_all_reduce`, usually + # when we are using CUDA graph. _TP_PYNCCL_COMMUNICATOR.disabled = True # Build the pipeline model-parallel groups. From 49f6d911cae732a704a0918b7009f98cebf9e07b Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 3 May 2024 23:18:13 -0700 Subject: [PATCH 25/43] add context manager --- .../device_communicators/pynccl.py | 33 +++++++++++++++++-- 1 file changed, 30 insertions(+), 3 deletions(-) diff --git a/vllm/distributed/device_communicators/pynccl.py b/vllm/distributed/device_communicators/pynccl.py index 337ef0f94e88a..f44ae895f7cfe 100644 --- a/vllm/distributed/device_communicators/pynccl.py +++ b/vllm/distributed/device_communicators/pynccl.py @@ -1,3 +1,4 @@ +from contextlib import contextmanager from typing import Optional, Union # ===================== import region ===================== @@ -39,14 +40,16 @@ def __init__( # explicit disable, e.g. world_size == 1 self.disabled = disabled if disabled: + self.available = False return try: self.nccl = NCCLLibrary(library_path) except Exception: - self.disabled = True + # disable because of missing NCCL library + # e.g. in a non-GPU environment + self.available = False return - # disable because of missing NCCL library - # e.g. in a non-GPU environment + self.available = True self.disabled = False logger.info("vLLM is using nccl==%s", self.nccl.ncclGetVersion()) @@ -113,3 +116,27 @@ def all_reduce(self, ncclDataTypeEnum.from_torch(tensor.dtype), ncclRedOpTypeEnum.from_torch(op), self.comm, cudaStream_t(stream.cuda_stream)) + + @contextmanager + def enable(self, + enable: Optional[bool] = None, + stream: Optional[torch.cuda.Stream] = None): + """ + A context manager to enable or disable the communicator. 
+ """ + if enable is None: + # guess a default value when not specified + enable = self.world_size > 1 and self.available + + if stream is None: + stream = torch.cuda.current_stream() + + old_disable = self.disabled + old_stream = self.stream + + self.stream = stream + self.disabled = not enable + yield + + self.disabled = old_disable + self.stream = old_stream From 38b148b164998aac72aebeb08985c04fe0fae716 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 3 May 2024 23:23:32 -0700 Subject: [PATCH 26/43] refactor logic of available --- .../device_communicators/pynccl.py | 30 +++++++++---------- vllm/distributed/parallel_state.py | 2 +- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/vllm/distributed/device_communicators/pynccl.py b/vllm/distributed/device_communicators/pynccl.py index f44ae895f7cfe..c431c51e86397 100644 --- a/vllm/distributed/device_communicators/pynccl.py +++ b/vllm/distributed/device_communicators/pynccl.py @@ -22,7 +22,6 @@ def __init__( group: Optional[ProcessGroup] = None, device: Optional[Union[int, str, torch.device]] = None, library_path: Optional[str] = None, - disabled: bool = False, ): """ Args: @@ -32,15 +31,22 @@ def __init__( it will be bind to f"cuda:{local_rank}". library_path: the path to the NCCL library. If None, it will use the default library path. - disabled: if True, the communicator will be disabled. This object - will not do anything, just serve as a placeholder. It is the caller's responsibility to make sure each communicator is bind to a unique device. """ - # explicit disable, e.g. world_size == 1 - self.disabled = disabled - if disabled: + assert dist.is_initialized() + group = get_cpu_world_group() if group is None else group + assert dist.get_backend(group) != dist.Backend.NCCL, ( + "NCCLCommunicator should be attached to a non-NCCL group.") + self.group = group + # note: this rank is the rank in the group + self.rank = dist.get_rank(group) + self.world_size = dist.get_world_size(group) + + # if world_size == 1, no need to create communicator + if self.world_size == 1: self.available = False + self.disabled = True return try: self.nccl = NCCLLibrary(library_path) @@ -48,20 +54,14 @@ def __init__( # disable because of missing NCCL library # e.g. in a non-GPU environment self.available = False + self.disabled = True return + self.available = True self.disabled = False logger.info("vLLM is using nccl==%s", self.nccl.ncclGetVersion()) - assert dist.is_initialized() - group = get_cpu_world_group() if group is None else group - assert dist.get_backend(group) != dist.Backend.NCCL, ( - "NCCLCommunicator should be attached to a non-NCCL group.") - self.group = group - # note: this rank is the rank in the group - self.rank = dist.get_rank(group) - self.world_size = dist.get_world_size(group) if self.rank == 0: # get the unique id from NCCL self.unique_id = self.nccl.ncclGetUniqueId() @@ -126,7 +126,7 @@ def enable(self, """ if enable is None: # guess a default value when not specified - enable = self.world_size > 1 and self.available + enable = self.available if stream is None: stream = torch.cuda.current_stream() diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index d30ac48439f56..797f6eaa66967 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -162,7 +162,7 @@ def initialize_model_parallel( _TP_PYNCCL_COMMUNICATOR = NCCLCommunicator( group=_TP_CPU_GROUP, device=_LOCAL_RANK, - disabled=tensor_model_parallel_size == 1) + ) # by default it is disabled, e.g. 
in profiling models and prefill phase. # to use it, we have to use under `with_pynccl_for_all_reduce`, usually From d2414803f92cc8c05b4fa70892accc74145eb865 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 3 May 2024 23:27:15 -0700 Subject: [PATCH 27/43] non-intrusive code --- vllm/distributed/device_communicators/pynccl.py | 5 +++++ vllm/distributed/parallel_state.py | 5 ----- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/vllm/distributed/device_communicators/pynccl.py b/vllm/distributed/device_communicators/pynccl.py index c431c51e86397..e0c3e18187dbc 100644 --- a/vllm/distributed/device_communicators/pynccl.py +++ b/vllm/distributed/device_communicators/pynccl.py @@ -97,6 +97,11 @@ def __init__( self.all_reduce(torch.zeros(1, device=device)) self.stream.synchronize() + # by default it is disabled, e.g. in profiling models and prefill phase. + # to use it, we have to use under `with obj.enable()`, usually + # when we are using CUDA graph. + self.disabled = True + def all_reduce(self, tensor: torch.Tensor, op: ReduceOp = ReduceOp.SUM, diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 797f6eaa66967..6a7cbea0b4f5f 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -164,11 +164,6 @@ def initialize_model_parallel( device=_LOCAL_RANK, ) - # by default it is disabled, e.g. in profiling models and prefill phase. - # to use it, we have to use under `with_pynccl_for_all_reduce`, usually - # when we are using CUDA graph. - _TP_PYNCCL_COMMUNICATOR.disabled = True - # Build the pipeline model-parallel groups. global _PIPELINE_MODEL_PARALLEL_GROUP global _PIPELINE_GLOBAL_RANKS From d7209f12da4fc7d59931424a6dc0b53260bd6de8 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 3 May 2024 23:48:22 -0700 Subject: [PATCH 28/43] clean up pynccl enable or disable --- tests/distributed/test_pynccl.py | 34 ++++++++++--------- vllm/distributed/communication_op.py | 14 ++++++++ .../device_communicators/pynccl.py | 2 +- vllm/distributed/parallel_state.py | 27 --------------- vllm/worker/model_runner.py | 17 +++------- 5 files changed, 37 insertions(+), 57 deletions(-) diff --git a/tests/distributed/test_pynccl.py b/tests/distributed/test_pynccl.py index baa1f9268fba0..e51676bcccad2 100644 --- a/tests/distributed/test_pynccl.py +++ b/tests/distributed/test_pynccl.py @@ -3,12 +3,12 @@ import pytest import torch -from vllm.distributed.communication_op import tensor_model_parallel_all_reduce +from vllm.distributed.communication_op import ( + tensor_model_parallel_all_reduce, use_pynccl_allreduce) from vllm.distributed.device_communicators.pynccl import NCCLCommunicator from vllm.distributed.device_communicators.pynccl_wrapper import NCCLLibrary from vllm.distributed.parallel_state import (ensure_model_parallel_initialized, - init_distributed_environment, - with_pynccl_for_all_reduce) + init_distributed_environment) from vllm.utils import update_environment_variables @@ -54,7 +54,8 @@ def wrapped_fn(env): def worker_fn(): comm = NCCLCommunicator() tensor = torch.ones(16, 1024, 1024, dtype=torch.float32).cuda(comm.rank) - comm.all_reduce(tensor) + with comm.enable(): + comm.all_reduce(tensor) result = tensor.mean().cpu().item() assert result == comm.world_size @@ -75,16 +76,17 @@ def multiple_tp_worker_fn(): group = groups[0] if torch.distributed.get_rank() in [0, 1] else groups[1] comm = NCCLCommunicator(group=group, device=device) tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device) - # two groups 
can communicate independently - if torch.distributed.get_rank() in [0, 1]: - comm.all_reduce(tensor) - comm.all_reduce(tensor) - result = tensor.mean().cpu().item() - assert result == 4 - else: - comm.all_reduce(tensor) - result = tensor.mean().cpu().item() - assert result == 2 + with comm.enable(): + # two groups can communicate independently + if torch.distributed.get_rank() in [0, 1]: + comm.all_reduce(tensor) + comm.all_reduce(tensor) + result = tensor.mean().cpu().item() + assert result == 4 + else: + comm.all_reduce(tensor) + result = tensor.mean().cpu().item() + assert result == 2 @pytest.mark.skipif(torch.cuda.device_count() < 4, @@ -100,7 +102,7 @@ def multiple_tp_with_vllm_worker_fn(): device = torch.device(f"cuda:{torch.distributed.get_rank()}") ensure_model_parallel_initialized(2, 2) tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device) - with with_pynccl_for_all_reduce(): + with use_pynccl_allreduce(): # two tp groups can communicate independently if torch.distributed.get_rank() in [0, 1]: tensor = tensor_model_parallel_all_reduce(tensor) @@ -129,7 +131,7 @@ def worker_fn_with_cudagraph(): # run something in the default stream to initialize torch engine a = torch.ones((4, 4), device=f'cuda:{comm.rank}') torch.cuda.synchronize() - with torch.cuda.graph(graph, stream=comm.stream): + with torch.cuda.graph(graph, stream=comm.stream), comm.enable(): # operation during the graph capture is recorded but not executed # see https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#creating-a-graph-using-stream-capture # noqa comm.all_reduce(a) diff --git a/vllm/distributed/communication_op.py b/vllm/distributed/communication_op.py index d39af1d771e2a..2ee188654ab59 100644 --- a/vllm/distributed/communication_op.py +++ b/vllm/distributed/communication_op.py @@ -3,6 +3,7 @@ import torch from torch.distributed import ProcessGroup +from contextlib import contextmanager from .parallel_state import (get_cpu_world_group, get_tensor_model_parallel_group, @@ -10,6 +11,19 @@ get_tensor_model_parallel_world_size) +@contextmanager +def use_pynccl_allreduce(): + from vllm.distributed.device_communicators import custom_all_reduce + if not custom_all_reduce.is_initialized(): + from vllm.distributed.parallel_state import _TP_PYNCCL_COMMUNICATOR + assert _TP_PYNCCL_COMMUNICATOR is not None + with _TP_PYNCCL_COMMUNICATOR.enable( + stream=torch.cuda.current_stream()): + yield + else: + yield + + def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor: """All-reduce the input tensor across model parallel group. diff --git a/vllm/distributed/device_communicators/pynccl.py b/vllm/distributed/device_communicators/pynccl.py index e0c3e18187dbc..2cefc0c417eb1 100644 --- a/vllm/distributed/device_communicators/pynccl.py +++ b/vllm/distributed/device_communicators/pynccl.py @@ -134,7 +134,7 @@ def enable(self, enable = self.available if stream is None: - stream = torch.cuda.current_stream() + stream = self.stream old_disable = self.disabled old_stream = self.stream diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 6a7cbea0b4f5f..890238d70b82c 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -3,7 +3,6 @@ # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 
"""Tensor and pipeline parallel groups.""" -import contextlib from typing import List, Optional import torch @@ -323,29 +322,3 @@ def destroy_model_parallel(): _PIPELINE_MODEL_PARALLEL_GROUP = None global _PIPELINE_GLOBAL_RANKS _PIPELINE_GLOBAL_RANKS = None - - -@contextlib.contextmanager -def with_pynccl_for_all_reduce(): - # We use pynccl for all reduce when using CUDA graph, because - # torch.distributed is not well supported by CUDA graph. - """use pynccl instead of torch.distributed for all reduce""" - tp_size = get_tensor_model_parallel_world_size() - if tp_size == 1: - # No-op. - # NOTE(woosuk): We don't initialize pynccl when tp_size is 1. - yield - else: - global _TP_PYNCCL_COMMUNICATOR - assert _TP_PYNCCL_COMMUNICATOR is not None - old = _TP_PYNCCL_COMMUNICATOR.disabled - _TP_PYNCCL_COMMUNICATOR.disabled = False - - stream = torch.cuda.current_stream() - old_stream = _TP_PYNCCL_COMMUNICATOR.stream - - _TP_PYNCCL_COMMUNICATOR.stream = stream - yield - - _TP_PYNCCL_COMMUNICATOR.stream = old_stream - _TP_PYNCCL_COMMUNICATOR.disabled = old diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index f647e994d1fc8..5b53703c7002f 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1,4 +1,3 @@ -import contextlib import time from enum import IntEnum from typing import Dict, List, NamedTuple, Optional, Set, Tuple @@ -11,8 +10,9 @@ get_attn_backend) from vllm.config import (DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, VisionLanguageConfig) -from vllm.distributed import broadcast_tensor_dict, with_pynccl_for_all_reduce +from vllm.distributed import broadcast_tensor_dict from vllm.distributed.device_communicators import custom_all_reduce +from vllm.distributed.communication_op import use_pynccl_allreduce from vllm.logger import init_logger from vllm.lora.layers import LoRAMapping from vllm.lora.request import LoRARequest @@ -982,7 +982,7 @@ def capture( # Run the model once without capturing the graph. # This is to make sure that the captured graph does not include the # kernel launches for initial benchmarking (e.g., Triton autotune). - with _maybe_pynccl(): + with use_pynccl_allreduce(): self.model( input_ids, positions, @@ -997,7 +997,7 @@ def capture( # https://stackoverflow.com/questions/31039022/python-multi-line-with-statement self._graph = torch.cuda.CUDAGraph() with torch.cuda.graph(self._graph, pool=memory_pool): # noqa: SIM117 - with _maybe_pynccl(): + with use_pynccl_allreduce(): hidden_states = self.model( input_ids, positions, @@ -1049,15 +1049,6 @@ def __call__(self, *args, **kwargs): return self.forward(*args, **kwargs) -@contextlib.contextmanager -def _maybe_pynccl(): - if not custom_all_reduce.is_initialized(): - with with_pynccl_for_all_reduce(): - yield - else: - yield - - def _get_graph_batch_size(batch_size: int) -> int: """Returns the padded batch size given actual batch size. 
From 7b550263bca94900af0e34d956d066323b68fe7e Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 3 May 2024 23:54:40 -0700 Subject: [PATCH 29/43] fix isort --- vllm/distributed/communication_op.py | 2 +- vllm/worker/model_runner.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/distributed/communication_op.py b/vllm/distributed/communication_op.py index 2ee188654ab59..e5efd09a0f329 100644 --- a/vllm/distributed/communication_op.py +++ b/vllm/distributed/communication_op.py @@ -1,9 +1,9 @@ from collections import namedtuple +from contextlib import contextmanager from typing import Any, Dict, List, Optional, Tuple, Union import torch from torch.distributed import ProcessGroup -from contextlib import contextmanager from .parallel_state import (get_cpu_world_group, get_tensor_model_parallel_group, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 5b53703c7002f..37dd058885a40 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -11,8 +11,8 @@ from vllm.config import (DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, VisionLanguageConfig) from vllm.distributed import broadcast_tensor_dict -from vllm.distributed.device_communicators import custom_all_reduce from vllm.distributed.communication_op import use_pynccl_allreduce +from vllm.distributed.device_communicators import custom_all_reduce from vllm.logger import init_logger from vllm.lora.layers import LoRAMapping from vllm.lora.request import LoRARequest From ee734b13a9bda9b6be1c904bcf439735148cbacf Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 4 May 2024 00:11:41 -0700 Subject: [PATCH 30/43] fix stream attribute --- vllm/distributed/device_communicators/pynccl.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/distributed/device_communicators/pynccl.py b/vllm/distributed/device_communicators/pynccl.py index 2cefc0c417eb1..1ccb58cdb330a 100644 --- a/vllm/distributed/device_communicators/pynccl.py +++ b/vllm/distributed/device_communicators/pynccl.py @@ -47,6 +47,7 @@ def __init__( if self.world_size == 1: self.available = False self.disabled = True + self.stream = None return try: self.nccl = NCCLLibrary(library_path) @@ -55,6 +56,7 @@ def __init__( # e.g. 
in a non-GPU environment self.available = False self.disabled = True + self.stream = None return self.available = True From 0516956ce52771a5d1495c171463d45698d3a2b5 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Wed, 8 May 2024 23:23:45 -0700 Subject: [PATCH 31/43] fix import --- tests/distributed/test_pynccl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/distributed/test_pynccl.py b/tests/distributed/test_pynccl.py index e51676bcccad2..9ed1092b25418 100644 --- a/tests/distributed/test_pynccl.py +++ b/tests/distributed/test_pynccl.py @@ -1,4 +1,5 @@ import multiprocessing +import os import pytest import torch @@ -40,7 +41,6 @@ def worker_fn_wrapper(fn): # and update the environment variables in the function def wrapped_fn(env): update_environment_variables(env) - import os local_rank = os.environ['LOCAL_RANK'] device = torch.device(f"cuda:{local_rank}") torch.cuda.set_device(device) From 9f63bf8a773f847837c2e2c7992a19cfcb1c3fef Mon Sep 17 00:00:00 2001 From: youkaichao Date: Wed, 8 May 2024 23:26:57 -0700 Subject: [PATCH 32/43] rename to PyNcclCommunicator and pynccl_comm --- tests/distributed/test_pynccl.py | 42 ++++++++++--------- .../device_communicators/pynccl.py | 6 +-- vllm/distributed/parallel_state.py | 4 +- 3 files changed, 27 insertions(+), 25 deletions(-) diff --git a/tests/distributed/test_pynccl.py b/tests/distributed/test_pynccl.py index 9ed1092b25418..5ba0ea690abdb 100644 --- a/tests/distributed/test_pynccl.py +++ b/tests/distributed/test_pynccl.py @@ -6,7 +6,7 @@ from vllm.distributed.communication_op import ( tensor_model_parallel_all_reduce, use_pynccl_allreduce) -from vllm.distributed.device_communicators.pynccl import NCCLCommunicator +from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator from vllm.distributed.device_communicators.pynccl_wrapper import NCCLLibrary from vllm.distributed.parallel_state import (ensure_model_parallel_initialized, init_distributed_environment) @@ -52,12 +52,13 @@ def wrapped_fn(env): @worker_fn_wrapper def worker_fn(): - comm = NCCLCommunicator() - tensor = torch.ones(16, 1024, 1024, dtype=torch.float32).cuda(comm.rank) - with comm.enable(): - comm.all_reduce(tensor) + pynccl_comm = PyNcclCommunicator() + tensor = torch.ones(16, 1024, 1024, + dtype=torch.float32).cuda(pynccl_comm.rank) + with pynccl_comm.enable(): + pynccl_comm.all_reduce(tensor) result = tensor.mean().cpu().item() - assert result == comm.world_size + assert result == pynccl_comm.world_size @pytest.mark.skipif(torch.cuda.device_count() < 2, @@ -74,17 +75,17 @@ def multiple_tp_worker_fn(): torch.distributed.new_group(ranks=[2, 3], backend="gloo") ] group = groups[0] if torch.distributed.get_rank() in [0, 1] else groups[1] - comm = NCCLCommunicator(group=group, device=device) + pynccl_comm = PyNcclCommunicator(group=group, device=device) tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device) - with comm.enable(): + with pynccl_comm.enable(): # two groups can communicate independently if torch.distributed.get_rank() in [0, 1]: - comm.all_reduce(tensor) - comm.all_reduce(tensor) + pynccl_comm.all_reduce(tensor) + pynccl_comm.all_reduce(tensor) result = tensor.mean().cpu().item() assert result == 4 else: - comm.all_reduce(tensor) + pynccl_comm.all_reduce(tensor) result = tensor.mean().cpu().item() assert result == 2 @@ -93,7 +94,7 @@ def multiple_tp_worker_fn(): reason="Need at least 4 GPUs to run the test.") def test_pynccl_multiple_tp(): # this tests pynccl for multiple tp groups, in a standalone way - # i.e. 
call `comm.all_reduce` directly + # i.e. call `pynccl_comm.all_reduce` directly distributed_run(multiple_tp_worker_fn, 4) @@ -127,19 +128,20 @@ def test_pynccl_multiple_tp_with_vllm(): def worker_fn_with_cudagraph(): with torch.no_grad(): graph = torch.cuda.CUDAGraph() - comm = NCCLCommunicator() + pynccl_comm = PyNcclCommunicator() # run something in the default stream to initialize torch engine - a = torch.ones((4, 4), device=f'cuda:{comm.rank}') + a = torch.ones((4, 4), device=f'cuda:{pynccl_comm.rank}') torch.cuda.synchronize() - with torch.cuda.graph(graph, stream=comm.stream), comm.enable(): + with torch.cuda.graph(graph, + stream=pynccl_comm.stream), pynccl_comm.enable(): # operation during the graph capture is recorded but not executed # see https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#creating-a-graph-using-stream-capture # noqa - comm.all_reduce(a) - comm.stream.synchronize() - assert a.mean().cpu().item() == comm.world_size**0 + pynccl_comm.all_reduce(a) + pynccl_comm.stream.synchronize() + assert a.mean().cpu().item() == pynccl_comm.world_size**0 graph.replay() - comm.stream.synchronize() - assert a.mean().cpu().item() == comm.world_size**1 + pynccl_comm.stream.synchronize() + assert a.mean().cpu().item() == pynccl_comm.world_size**1 @pytest.mark.skipif(torch.cuda.device_count() < 2, diff --git a/vllm/distributed/device_communicators/pynccl.py b/vllm/distributed/device_communicators/pynccl.py index 1ccb58cdb330a..0948b03ea893d 100644 --- a/vllm/distributed/device_communicators/pynccl.py +++ b/vllm/distributed/device_communicators/pynccl.py @@ -15,7 +15,7 @@ logger = init_logger(__name__) -class NCCLCommunicator: +class PyNcclCommunicator: def __init__( self, @@ -27,7 +27,7 @@ def __init__( Args: group: the process group to work on. If None, it will use the default process group. - device: the device to bind the NCCLCommunicator to. If None, + device: the device to bind the PyNcclCommunicator to. If None, it will be bind to f"cuda:{local_rank}". library_path: the path to the NCCL library. If None, it will use the default library path. 
@@ -37,7 +37,7 @@ def __init__( assert dist.is_initialized() group = get_cpu_world_group() if group is None else group assert dist.get_backend(group) != dist.Backend.NCCL, ( - "NCCLCommunicator should be attached to a non-NCCL group.") + "PyNcclCommunicator should be attached to a non-NCCL group.") self.group = group # note: this rank is the rank in the group self.rank = dist.get_rank(group) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 890238d70b82c..56e2745bfb7d5 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -157,8 +157,8 @@ def initialize_model_parallel( _TP_DEVICE_GROUP = group _TP_CPU_GROUP = cpu_group - from vllm.distributed.device_communicators.pynccl import NCCLCommunicator - _TP_PYNCCL_COMMUNICATOR = NCCLCommunicator( + from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator + _TP_PYNCCL_COMMUNICATOR = PyNcclCommunicator( group=_TP_CPU_GROUP, device=_LOCAL_RANK, ) From e9aa766c278166db7d141117c489d498a54af29d Mon Sep 17 00:00:00 2001 From: youkaichao Date: Wed, 8 May 2024 23:40:26 -0700 Subject: [PATCH 33/43] rename use_pynccl_allreduce --- tests/distributed/test_pynccl.py | 4 ++-- vllm/distributed/communication_op.py | 12 +++++++++++- vllm/worker/model_runner.py | 6 +++--- 3 files changed, 16 insertions(+), 6 deletions(-) diff --git a/tests/distributed/test_pynccl.py b/tests/distributed/test_pynccl.py index 5ba0ea690abdb..8d24aa0b663bd 100644 --- a/tests/distributed/test_pynccl.py +++ b/tests/distributed/test_pynccl.py @@ -5,7 +5,7 @@ import torch from vllm.distributed.communication_op import ( - tensor_model_parallel_all_reduce, use_pynccl_allreduce) + graph_capture_mode, tensor_model_parallel_all_reduce) from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator from vllm.distributed.device_communicators.pynccl_wrapper import NCCLLibrary from vllm.distributed.parallel_state import (ensure_model_parallel_initialized, @@ -103,7 +103,7 @@ def multiple_tp_with_vllm_worker_fn(): device = torch.device(f"cuda:{torch.distributed.get_rank()}") ensure_model_parallel_initialized(2, 2) tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device) - with use_pynccl_allreduce(): + with graph_capture_mode(): # two tp groups can communicate independently if torch.distributed.get_rank() in [0, 1]: tensor = tensor_model_parallel_all_reduce(tensor) diff --git a/vllm/distributed/communication_op.py b/vllm/distributed/communication_op.py index e5efd09a0f329..3163cdeac743a 100644 --- a/vllm/distributed/communication_op.py +++ b/vllm/distributed/communication_op.py @@ -12,7 +12,17 @@ @contextmanager -def use_pynccl_allreduce(): +def graph_capture_mode(): + # In graph capture, we have to be very careful about the collective + # operations. The current status is: + # allreduce \ Mode | Eager | Graph | + # -------------------------------------------- + # custom allreduce | enabled | enabled | + # PyNccl | disabled| enabled | + # torch.distributed | enabled | disabled| + # + # Note that custom allreduce will have a runtime check, if the tensor size + # is too large, it will fallback to the next available option. 
from vllm.distributed.device_communicators import custom_all_reduce if not custom_all_reduce.is_initialized(): from vllm.distributed.parallel_state import _TP_PYNCCL_COMMUNICATOR diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 37dd058885a40..30aa10c12d396 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -11,7 +11,7 @@ from vllm.config import (DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, VisionLanguageConfig) from vllm.distributed import broadcast_tensor_dict -from vllm.distributed.communication_op import use_pynccl_allreduce +from vllm.distributed.communication_op import graph_capture_mode from vllm.distributed.device_communicators import custom_all_reduce from vllm.logger import init_logger from vllm.lora.layers import LoRAMapping @@ -982,7 +982,7 @@ def capture( # Run the model once without capturing the graph. # This is to make sure that the captured graph does not include the # kernel launches for initial benchmarking (e.g., Triton autotune). - with use_pynccl_allreduce(): + with graph_capture_mode(): self.model( input_ids, positions, @@ -997,7 +997,7 @@ def capture( # https://stackoverflow.com/questions/31039022/python-multi-line-with-statement self._graph = torch.cuda.CUDAGraph() with torch.cuda.graph(self._graph, pool=memory_pool): # noqa: SIM117 - with use_pynccl_allreduce(): + with graph_capture_mode(): hidden_states = self.model( input_ids, positions, From 0f643017b0488c7901022b4c786d54a91ad5f082 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Wed, 8 May 2024 23:45:46 -0700 Subject: [PATCH 34/43] fix lint --- tests/distributed/test_pynccl.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/distributed/test_pynccl.py b/tests/distributed/test_pynccl.py index 8d24aa0b663bd..a0861185cc129 100644 --- a/tests/distributed/test_pynccl.py +++ b/tests/distributed/test_pynccl.py @@ -4,8 +4,8 @@ import pytest import torch -from vllm.distributed.communication_op import ( - graph_capture_mode, tensor_model_parallel_all_reduce) +from vllm.distributed.communication_op import (graph_capture_mode, + tensor_model_parallel_all_reduce) from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator from vllm.distributed.device_communicators.pynccl_wrapper import NCCLLibrary from vllm.distributed.parallel_state import (ensure_model_parallel_initialized, From a64962ee39d4b1294817782b74a1606966855031 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Wed, 8 May 2024 23:46:14 -0700 Subject: [PATCH 35/43] fix lint --- tests/distributed/test_pynccl.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/distributed/test_pynccl.py b/tests/distributed/test_pynccl.py index a0861185cc129..b558d8bdd9d8d 100644 --- a/tests/distributed/test_pynccl.py +++ b/tests/distributed/test_pynccl.py @@ -4,8 +4,8 @@ import pytest import torch -from vllm.distributed.communication_op import (graph_capture_mode, - tensor_model_parallel_all_reduce) +from vllm.distributed.communication_op import graph_capture_mode +from vllm.distributed.communication_op import tensor_model_parallel_all_reduce from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator from vllm.distributed.device_communicators.pynccl_wrapper import NCCLLibrary from vllm.distributed.parallel_state import (ensure_model_parallel_initialized, From d2f83ba80ae1058d42865a1f1ec4f12051dc1de5 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Wed, 8 May 2024 23:50:40 -0700 Subject: [PATCH 36/43] fix lint 
--- tests/distributed/test_pynccl.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/distributed/test_pynccl.py b/tests/distributed/test_pynccl.py index b558d8bdd9d8d..c2b744f42acd4 100644 --- a/tests/distributed/test_pynccl.py +++ b/tests/distributed/test_pynccl.py @@ -4,8 +4,8 @@ import pytest import torch -from vllm.distributed.communication_op import graph_capture_mode -from vllm.distributed.communication_op import tensor_model_parallel_all_reduce +from vllm.distributed.communication_op import ( # noqa + graph_capture_mode, tensor_model_parallel_all_reduce) from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator from vllm.distributed.device_communicators.pynccl_wrapper import NCCLLibrary from vllm.distributed.parallel_state import (ensure_model_parallel_initialized, From 12f309b3b823733d33bdc24fa267fe200c5bfd63 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Wed, 8 May 2024 23:53:03 -0700 Subject: [PATCH 37/43] fix dependency on custom_all_reduce --- vllm/distributed/communication_op.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/vllm/distributed/communication_op.py b/vllm/distributed/communication_op.py index 3163cdeac743a..18e0fb1fca0d5 100644 --- a/vllm/distributed/communication_op.py +++ b/vllm/distributed/communication_op.py @@ -23,14 +23,9 @@ def graph_capture_mode(): # # Note that custom allreduce will have a runtime check, if the tensor size # is too large, it will fallback to the next available option. - from vllm.distributed.device_communicators import custom_all_reduce - if not custom_all_reduce.is_initialized(): - from vllm.distributed.parallel_state import _TP_PYNCCL_COMMUNICATOR - assert _TP_PYNCCL_COMMUNICATOR is not None - with _TP_PYNCCL_COMMUNICATOR.enable( - stream=torch.cuda.current_stream()): - yield - else: + from vllm.distributed.parallel_state import _TP_PYNCCL_COMMUNICATOR + assert _TP_PYNCCL_COMMUNICATOR is not None + with _TP_PYNCCL_COMMUNICATOR.enable(stream=torch.cuda.current_stream()): yield From 68e448c8039c3d6ee9d939ebd88d6e90d1a0c81f Mon Sep 17 00:00:00 2001 From: youkaichao Date: Wed, 8 May 2024 23:55:33 -0700 Subject: [PATCH 38/43] fix lint --- vllm/distributed/communication_op.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/distributed/communication_op.py b/vllm/distributed/communication_op.py index 18e0fb1fca0d5..d07adbe04a108 100644 --- a/vllm/distributed/communication_op.py +++ b/vllm/distributed/communication_op.py @@ -51,8 +51,8 @@ def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor: out = custom_all_reduce(input_) if out is not None: return out - if _TP_PYNCCL_COMMUNICATOR is not None and \ - not _TP_PYNCCL_COMMUNICATOR.disabled: + if (_TP_PYNCCL_COMMUNICATOR is not None + and not _TP_PYNCCL_COMMUNICATOR.disabled): _TP_PYNCCL_COMMUNICATOR.all_reduce(input_) else: torch.distributed.all_reduce(input_, From ad6f84037ba9e4b490a7e5289176a14a0513a90a Mon Sep 17 00:00:00 2001 From: youkaichao Date: Wed, 8 May 2024 23:59:47 -0700 Subject: [PATCH 39/43] use _PP_DEVICE_GROUP --- vllm/distributed/parallel_state.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 56e2745bfb7d5..e8af0f1eae649 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -18,7 +18,7 @@ _TP_CPU_GROUP: Optional[ProcessGroup] = None _TP_PYNCCL_COMMUNICATOR = None # Pipeline model parallel group 
that the current rank belongs to. -_PIPELINE_MODEL_PARALLEL_GROUP: Optional[ProcessGroup] = None +_PP_DEVICE_GROUP: Optional[ProcessGroup] = None # when people blindly call `torch.distributed.all_reduce` etc, # it will use this group. It is initialized with the `backend` @@ -164,15 +164,15 @@ def initialize_model_parallel( ) # Build the pipeline model-parallel groups. - global _PIPELINE_MODEL_PARALLEL_GROUP + global _PP_DEVICE_GROUP global _PIPELINE_GLOBAL_RANKS - assert _PIPELINE_MODEL_PARALLEL_GROUP is None, ( + assert _PP_DEVICE_GROUP is None, ( "pipeline model parallel group is already initialized") for i in range(num_pipeline_model_parallel_groups): ranks = list(range(i, world_size, num_pipeline_model_parallel_groups)) group = torch.distributed.new_group(ranks, backend=backend) if rank in ranks: - _PIPELINE_MODEL_PARALLEL_GROUP = group + _PP_DEVICE_GROUP = group _PIPELINE_GLOBAL_RANKS = ranks @@ -207,7 +207,7 @@ def ensure_model_parallel_initialized( def model_parallel_is_initialized(): """Check if tensor and pipeline parallel groups are initialized.""" return (_TP_DEVICE_GROUP is not None - and _PIPELINE_MODEL_PARALLEL_GROUP is not None) + and _PP_DEVICE_GROUP is not None) def get_cpu_world_group(): @@ -232,9 +232,9 @@ def get_tensor_model_parallel_cpu_group(): def get_pipeline_model_parallel_group(): """Get the pipeline model parallel group the caller rank belongs to.""" - assert _PIPELINE_MODEL_PARALLEL_GROUP is not None, ( + assert _PP_DEVICE_GROUP is not None, ( "pipeline model parallel group is not initialized") - return _PIPELINE_MODEL_PARALLEL_GROUP + return _PP_DEVICE_GROUP def get_tensor_model_parallel_world_size(): @@ -316,9 +316,9 @@ def destroy_model_parallel(): global _TP_PYNCCL_COMMUNICATOR _TP_PYNCCL_COMMUNICATOR = None - global _PIPELINE_MODEL_PARALLEL_GROUP - if _PIPELINE_MODEL_PARALLEL_GROUP: - torch.distributed.destroy_process_group(_PIPELINE_MODEL_PARALLEL_GROUP) - _PIPELINE_MODEL_PARALLEL_GROUP = None + global _PP_DEVICE_GROUP + if _PP_DEVICE_GROUP: + torch.distributed.destroy_process_group(_PP_DEVICE_GROUP) + _PP_DEVICE_GROUP = None global _PIPELINE_GLOBAL_RANKS _PIPELINE_GLOBAL_RANKS = None From e2153b21197d09c162213e822a817aef7c93835d Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 9 May 2024 00:00:53 -0700 Subject: [PATCH 40/43] use _PP_GLOBAL_RANKS --- vllm/distributed/parallel_state.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index e8af0f1eae649..a1720c4743878 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -42,7 +42,7 @@ # A list of global ranks for each pipeline group to ease calculation of the # source rank when broadcasting from the first or last pipeline stage. -_PIPELINE_GLOBAL_RANKS: Optional[List[int]] = None +_PP_GLOBAL_RANKS: Optional[List[int]] = None _LOCAL_RANK = -1 @@ -165,7 +165,7 @@ def initialize_model_parallel( # Build the pipeline model-parallel groups. 
global _PP_DEVICE_GROUP - global _PIPELINE_GLOBAL_RANKS + global _PP_GLOBAL_RANKS assert _PP_DEVICE_GROUP is None, ( "pipeline model parallel group is already initialized") for i in range(num_pipeline_model_parallel_groups): @@ -173,7 +173,7 @@ def initialize_model_parallel( group = torch.distributed.new_group(ranks, backend=backend) if rank in ranks: _PP_DEVICE_GROUP = group - _PIPELINE_GLOBAL_RANKS = ranks + _PP_GLOBAL_RANKS = ranks def ensure_model_parallel_initialized( @@ -271,36 +271,36 @@ def get_tensor_model_parallel_src_rank(): def get_pipeline_model_parallel_first_rank(): """Return the global rank of the first process in the pipeline for the current tensor parallel group""" - assert _PIPELINE_GLOBAL_RANKS is not None, ( + assert _PP_GLOBAL_RANKS is not None, ( "Pipeline parallel group is not initialized") - return _PIPELINE_GLOBAL_RANKS[0] + return _PP_GLOBAL_RANKS[0] def get_pipeline_model_parallel_last_rank(): """Return the global rank of the last process in the pipeline for the current tensor parallel group""" - assert _PIPELINE_GLOBAL_RANKS is not None, ( + assert _PP_GLOBAL_RANKS is not None, ( "Pipeline parallel group is not initialized") last_rank_local = get_pipeline_model_parallel_world_size() - 1 - return _PIPELINE_GLOBAL_RANKS[last_rank_local] + return _PP_GLOBAL_RANKS[last_rank_local] def get_pipeline_model_parallel_next_rank(): """Return the global rank that follows the caller in the pipeline""" - assert _PIPELINE_GLOBAL_RANKS is not None, ( + assert _PP_GLOBAL_RANKS is not None, ( "Pipeline parallel group is not initialized") rank_in_pipeline = get_pipeline_model_parallel_rank() world_size = get_pipeline_model_parallel_world_size() - return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline + 1) % world_size] + return _PP_GLOBAL_RANKS[(rank_in_pipeline + 1) % world_size] def get_pipeline_model_parallel_prev_rank(): """Return the global rank that precedes the caller in the pipeline""" - assert _PIPELINE_GLOBAL_RANKS is not None, ( + assert _PP_GLOBAL_RANKS is not None, ( "Pipeline parallel group is not initialized") rank_in_pipeline = get_pipeline_model_parallel_rank() world_size = get_pipeline_model_parallel_world_size() - return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline - 1) % world_size] + return _PP_GLOBAL_RANKS[(rank_in_pipeline - 1) % world_size] def destroy_model_parallel(): @@ -320,5 +320,5 @@ def destroy_model_parallel(): if _PP_DEVICE_GROUP: torch.distributed.destroy_process_group(_PP_DEVICE_GROUP) _PP_DEVICE_GROUP = None - global _PIPELINE_GLOBAL_RANKS - _PIPELINE_GLOBAL_RANKS = None + global _PP_GLOBAL_RANKS + _PP_GLOBAL_RANKS = None From 80aca9410ec7265e6b783e92d8570167a84d5e0f Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 9 May 2024 00:01:05 -0700 Subject: [PATCH 41/43] fix lint --- vllm/distributed/parallel_state.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index a1720c4743878..bd48b0619e922 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -206,8 +206,7 @@ def ensure_model_parallel_initialized( def model_parallel_is_initialized(): """Check if tensor and pipeline parallel groups are initialized.""" - return (_TP_DEVICE_GROUP is not None - and _PP_DEVICE_GROUP is not None) + return (_TP_DEVICE_GROUP is not None and _PP_DEVICE_GROUP is not None) def get_cpu_world_group(): From c1b1cdbce0934fc5d366cad103a9a7877d86daea Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 9 May 2024 00:13:49 -0700 Subject: [PATCH 42/43] use 
change_state rather than enable --- tests/distributed/test_pynccl.py | 9 +++++---- vllm/distributed/communication_op.py | 3 ++- vllm/distributed/device_communicators/pynccl.py | 10 +++++----- 3 files changed, 12 insertions(+), 10 deletions(-) diff --git a/tests/distributed/test_pynccl.py b/tests/distributed/test_pynccl.py index c2b744f42acd4..b3e30a0434423 100644 --- a/tests/distributed/test_pynccl.py +++ b/tests/distributed/test_pynccl.py @@ -55,7 +55,7 @@ def worker_fn(): pynccl_comm = PyNcclCommunicator() tensor = torch.ones(16, 1024, 1024, dtype=torch.float32).cuda(pynccl_comm.rank) - with pynccl_comm.enable(): + with pynccl_comm.change_state(enable=True): pynccl_comm.all_reduce(tensor) result = tensor.mean().cpu().item() assert result == pynccl_comm.world_size @@ -77,7 +77,7 @@ def multiple_tp_worker_fn(): group = groups[0] if torch.distributed.get_rank() in [0, 1] else groups[1] pynccl_comm = PyNcclCommunicator(group=group, device=device) tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device) - with pynccl_comm.enable(): + with pynccl_comm.change_state(enable=True): # two groups can communicate independently if torch.distributed.get_rank() in [0, 1]: pynccl_comm.all_reduce(tensor) @@ -132,8 +132,9 @@ def worker_fn_with_cudagraph(): # run something in the default stream to initialize torch engine a = torch.ones((4, 4), device=f'cuda:{pynccl_comm.rank}') torch.cuda.synchronize() - with torch.cuda.graph(graph, - stream=pynccl_comm.stream), pynccl_comm.enable(): + with torch.cuda.graph( + graph, stream=pynccl_comm.stream), pynccl_comm.change_state( + enable=True): # operation during the graph capture is recorded but not executed # see https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#creating-a-graph-using-stream-capture # noqa pynccl_comm.all_reduce(a) diff --git a/vllm/distributed/communication_op.py b/vllm/distributed/communication_op.py index d07adbe04a108..76d430fd0df6b 100644 --- a/vllm/distributed/communication_op.py +++ b/vllm/distributed/communication_op.py @@ -25,7 +25,8 @@ def graph_capture_mode(): # is too large, it will fallback to the next available option. from vllm.distributed.parallel_state import _TP_PYNCCL_COMMUNICATOR assert _TP_PYNCCL_COMMUNICATOR is not None - with _TP_PYNCCL_COMMUNICATOR.enable(stream=torch.cuda.current_stream()): + with _TP_PYNCCL_COMMUNICATOR.change_state( + enable=True, stream=torch.cuda.current_stream()): yield diff --git a/vllm/distributed/device_communicators/pynccl.py b/vllm/distributed/device_communicators/pynccl.py index 0948b03ea893d..168d4cc2df8a6 100644 --- a/vllm/distributed/device_communicators/pynccl.py +++ b/vllm/distributed/device_communicators/pynccl.py @@ -100,7 +100,7 @@ def __init__( self.stream.synchronize() # by default it is disabled, e.g. in profiling models and prefill phase. - # to use it, we have to use under `with obj.enable()`, usually + # to use it, use under `with obj.change_state(enable=True)`, usually # when we are using CUDA graph. self.disabled = True @@ -125,11 +125,11 @@ def all_reduce(self, cudaStream_t(stream.cuda_stream)) @contextmanager - def enable(self, - enable: Optional[bool] = None, - stream: Optional[torch.cuda.Stream] = None): + def change_state(self, + enable: Optional[bool] = None, + stream: Optional[torch.cuda.Stream] = None): """ - A context manager to enable or disable the communicator. + A context manager to change the state of the communicator. 
""" if enable is None: # guess a default value when not specified From 70a7e2679b477e02c6559c4c47f9368d7c728ad3 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 9 May 2024 16:05:09 -0700 Subject: [PATCH 43/43] add get_tp_pynccl_communicator --- vllm/distributed/communication_op.py | 18 +++++++++--------- vllm/distributed/parallel_state.py | 5 +++++ 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/vllm/distributed/communication_op.py b/vllm/distributed/communication_op.py index f8afaee9177c2..32ab5694e5390 100644 --- a/vllm/distributed/communication_op.py +++ b/vllm/distributed/communication_op.py @@ -8,7 +8,8 @@ from .parallel_state import (get_cpu_world_group, get_tensor_model_parallel_group, get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size) + get_tensor_model_parallel_world_size, + get_tp_pynccl_communicator) @contextmanager @@ -23,10 +24,10 @@ def graph_capture_mode(): # # Note that custom allreduce will have a runtime check, if the tensor size # is too large, it will fallback to the next available option. - from vllm.distributed.parallel_state import _TP_PYNCCL_COMMUNICATOR - assert _TP_PYNCCL_COMMUNICATOR is not None - with _TP_PYNCCL_COMMUNICATOR.change_state( - enable=True, stream=torch.cuda.current_stream()): + pynccl_comm = get_tp_pynccl_communicator() + assert pynccl_comm is not None + with pynccl_comm.change_state(enable=True, + stream=torch.cuda.current_stream()): yield @@ -44,7 +45,6 @@ def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor: """ from vllm.distributed.device_communicators.custom_all_reduce import ( custom_all_reduce) - from vllm.distributed.parallel_state import _TP_PYNCCL_COMMUNICATOR # Bypass the function if we are using only 1 GPU. if get_tensor_model_parallel_world_size() == 1: @@ -52,9 +52,9 @@ def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor: out = custom_all_reduce(input_) if out is not None: return out - if (_TP_PYNCCL_COMMUNICATOR is not None - and not _TP_PYNCCL_COMMUNICATOR.disabled): - _TP_PYNCCL_COMMUNICATOR.all_reduce(input_) + pynccl_comm = get_tp_pynccl_communicator() + if (pynccl_comm is not None and not pynccl_comm.disabled): + pynccl_comm.all_reduce(input_) else: torch.distributed.all_reduce(input_, group=get_tensor_model_parallel_group()) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index bd48b0619e922..5075da11bb1b8 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -47,6 +47,11 @@ _LOCAL_RANK = -1 +def get_tp_pynccl_communicator(): + global _TP_PYNCCL_COMMUNICATOR + return _TP_PYNCCL_COMMUNICATOR + + def get_local_rank(): global _LOCAL_RANK return _LOCAL_RANK